Adding vllm package
This view is limited to 50 files because the commit contains too many changes.
- .buildkite/run-benchmarks.sh +63 -0
- .buildkite/test-pipeline.yaml +51 -0
- .buildkite/test-template.j2 +54 -0
- .dockerignore +1 -0
- .github/workflows/publish.yml +102 -0
- .github/workflows/ruff.yml +31 -0
- .github/workflows/scripts/build.sh +20 -0
- .github/workflows/scripts/create_release.js +20 -0
- .github/workflows/scripts/cuda-install.sh +23 -0
- .github/workflows/scripts/env.sh +56 -0
- .github/workflows/scripts/pytorch-install.sh +15 -0
- .github/workflows/yapf.yml +31 -0
- .gitignore +186 -0
- .readthedocs.yaml +21 -0
- CONTRIBUTING.md +77 -0
- Dockerfile +1 -1
- Dockerfile.rocm +88 -0
- LICENSE +201 -0
- MANIFEST.in +4 -0
- README.md +110 -7
- benchmarks/README.md +8 -0
- benchmarks/benchmark_latency.py +139 -0
- benchmarks/benchmark_serving.py +249 -0
- benchmarks/benchmark_throughput.py +328 -0
- benchmarks/kernels/benchmark_paged_attention.py +196 -0
- benchmarks/launch_tgi_server.sh +16 -0
- csrc/activation_kernels.cu +118 -0
- csrc/attention/attention_dtypes.h +7 -0
- csrc/attention/attention_generic.cuh +64 -0
- csrc/attention/attention_kernels.cu +951 -0
- csrc/attention/attention_utils.cuh +56 -0
- csrc/attention/dtype_bfloat16.cuh +451 -0
- csrc/attention/dtype_float16.cuh +502 -0
- csrc/attention/dtype_float32.cuh +273 -0
- csrc/attention/dtype_fp8_e5m2.cuh +35 -0
- csrc/cache.h +36 -0
- csrc/cache_kernels.cu +474 -0
- csrc/cuda_compat.h +28 -0
- csrc/cuda_utils.h +10 -0
- csrc/cuda_utils_kernels.cu +35 -0
- csrc/custom_all_reduce.cu +148 -0
- csrc/custom_all_reduce.cuh +562 -0
- csrc/custom_all_reduce_test.cu +284 -0
- csrc/dispatch_utils.h +37 -0
- csrc/layernorm_kernels.cu +120 -0
- csrc/moe_align_block_size_kernels.cu +108 -0
- csrc/ops.h +130 -0
- csrc/pos_encoding_kernels.cu +130 -0
- csrc/punica/LICENSE +217 -0
- csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu +4 -0
.buildkite/run-benchmarks.sh
ADDED
```bash
# This script is run by buildkite to run the benchmarks and upload the results to buildkite

set -ex
set -o pipefail

# cd into parent directory of this file
cd "$(dirname "${BASH_SOURCE[0]}")/.."

(wget && curl) || (apt-get update && apt-get install -y wget curl)

# run benchmarks and upload the result to buildkite
python3 benchmarks/benchmark_latency.py 2>&1 | tee benchmark_latency.txt
bench_latency_exit_code=$?

python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 2>&1 | tee benchmark_throughput.txt
bench_throughput_exit_code=$?

python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-chat-hf &
server_pid=$!
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json

# wait for server to start, timeout after 600 seconds
timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
python3 benchmarks/benchmark_serving.py \
    --dataset ./ShareGPT_V3_unfiltered_cleaned_split.json \
    --model meta-llama/Llama-2-7b-chat-hf \
    --num-prompts 20 \
    --endpoint /v1/completions \
    --tokenizer meta-llama/Llama-2-7b-chat-hf 2>&1 | tee benchmark_serving.txt
bench_serving_exit_code=$?
kill $server_pid

# write the results into a markdown file
echo "### Latency Benchmarks" >> benchmark_results.md
sed -n '1p' benchmark_latency.txt >> benchmark_results.md # first line
echo "" >> benchmark_results.md
sed -n '$p' benchmark_latency.txt >> benchmark_results.md # last line

echo "### Throughput Benchmarks" >> benchmark_results.md
sed -n '1p' benchmark_throughput.txt >> benchmark_results.md # first line
echo "" >> benchmark_results.md
sed -n '$p' benchmark_throughput.txt >> benchmark_results.md # last line

echo "### Serving Benchmarks" >> benchmark_results.md
sed -n '1p' benchmark_serving.txt >> benchmark_results.md # first line
echo "" >> benchmark_results.md
tail -n 5 benchmark_serving.txt >> benchmark_results.md # last 5 lines

# upload the results to buildkite
/workspace/buildkite-agent annotate --style "info" --context "benchmark-results" < benchmark_results.md

# exit with the exit code of the benchmarks
if [ $bench_latency_exit_code -ne 0 ]; then
    exit $bench_latency_exit_code
fi

if [ $bench_throughput_exit_code -ne 0 ]; then
    exit $bench_throughput_exit_code
fi

if [ $bench_serving_exit_code -ne 0 ]; then
    exit $bench_serving_exit_code
fi
```
.buildkite/test-pipeline.yaml
ADDED
```yaml
# In this file, you can add more tests to run either by adding a new step or
# adding a new command to an existing step. See different options here for examples.
# This file will be fed into the Jinja template in `test-template.j2` to generate
# the final pipeline yaml file.

steps:
- label: Regression Test
  command: pytest -v -s test_regression.py
  working_dir: "/vllm-workspace/tests" # optional

- label: AsyncEngine Test
  command: pytest -v -s async_engine

- label: Distributed Test
  command: pytest -v -s test_comm_ops.py
  working_dir: "/vllm-workspace/tests/distributed"
  num_gpus: 2 # only support 1 or 2 for now.

- label: Engine Test
  command: pytest -v -s engine

- label: Entrypoints Test
  command: pytest -v -s entrypoints

- label: Kernels Test
  command: pytest -v -s kernels
  soft_fail: true

- label: Models Test
  commands:
  - pytest -v -s models --forked
  soft_fail: true

- label: Prefix Caching Test
  commands:
  - pytest -v -s prefix_caching

- label: Samplers Test
  command: pytest -v -s samplers --forked

- label: Worker Test
  command: pytest -v -s worker

- label: LoRA Test
  command: pytest -v -s lora

- label: Benchmarks
  working_dir: "/vllm-workspace/.buildkite"
  commands:
  - pip install aiohttp
  - bash run-benchmarks.sh
```
.buildkite/test-template.j2
ADDED
```jinja
{% set docker_image = "us-central1-docker.pkg.dev/vllm-405802/vllm-ci-test-repo/vllm-test:$BUILDKITE_COMMIT" %}
{% set default_num_gpu = 1 %}
{% set default_working_dir = "/vllm-workspace/tests" %}

steps:
  - label: ":docker: build image"
    commands:
      - "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ."
      - "docker push {{ docker_image }}"
    env:
      DOCKER_BUILDKIT: "1"
    retry:
      automatic:
        - exit_status: -1  # Agent was lost
          limit: 5
  - wait

  {% for step in steps %}
  - label: "{{ step.label }}"
    agents:
      queue: kubernetes
    soft_fail: {{ step.soft_fail or false }}
    retry:
      automatic:
        - exit_status: -1  # Agent was lost
          limit: 5
    plugins:
      - kubernetes:
          podSpec:
            volumes:
              - name: dshm
                emptyDir:
                  medium: Memory
            containers:
              - image: "{{ docker_image }}"
                command: ["bash"]
                args:
                  - "-c"
                  - "'cd {{ (step.working_dir or default_working_dir) | safe }} && {{ step.command or (step.commands | join(' && ')) | safe }}'"
                resources:
                  requests:
                    nvidia.com/gpu: "{{ step.num_gpus or default_num_gpu }}"
                  limits:
                    nvidia.com/gpu: "{{ step.num_gpus or default_num_gpu }}"
                env:
                  - name: HF_TOKEN
                    valueFrom:
                      secretKeyRef:
                        name: hf-token-secret
                        key: token
                volumeMounts:
                  - mountPath: /dev/shm
                    name: dshm
  {% endfor %}
```
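The two files above work together: the steps in `test-pipeline.yaml` are rendered through `test-template.j2` to produce the final Buildkite pipeline. The commit does not show the rendering mechanism itself, so the following is only a minimal sketch of how such a render could be done, assuming `jinja2` and `PyYAML` are available:

```python
# Hypothetical rendering step (not part of this commit): feed the steps from
# test-pipeline.yaml through test-template.j2 to produce the pipeline YAML.
import yaml
from jinja2 import Template

with open(".buildkite/test-pipeline.yaml") as f:
    steps = yaml.safe_load(f)["steps"]

with open(".buildkite/test-template.j2") as f:
    template = Template(f.read())

# The rendered text is the Buildkite pipeline definition that gets uploaded.
print(template.render(steps=steps))
```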
.dockerignore
ADDED
```
vllm/*.so
```
.github/workflows/publish.yml
ADDED
```yaml
# This workflow will upload a Python Package to Release asset
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions

name: Create Release

on:
  push:
    tags:
      - v*

# Needed to create release and upload assets
permissions:
  contents: write

jobs:
  release:
    # Retrieve tag and create release
    name: Create Release
    runs-on: ubuntu-latest
    outputs:
      upload_url: ${{ steps.create_release.outputs.upload_url }}
    steps:
      - name: Checkout
        uses: actions/checkout@v3

      - name: Extract branch info
        shell: bash
        run: |
          echo "release_tag=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV

      - name: Create Release
        id: create_release
        uses: "actions/github-script@v6"
        env:
          RELEASE_TAG: ${{ env.release_tag }}
        with:
          github-token: "${{ secrets.GITHUB_TOKEN }}"
          script: |
            const script = require('.github/workflows/scripts/create_release.js')
            await script(github, context, core)

  wheel:
    name: Build Wheel
    runs-on: ${{ matrix.os }}
    needs: release

    strategy:
      fail-fast: false
      matrix:
        os: ['ubuntu-20.04']
        python-version: ['3.8', '3.9', '3.10', '3.11']
        pytorch-version: ['2.1.2']  # Must be the most recent version that meets requirements.txt.
        cuda-version: ['11.8', '12.1']

    steps:
      - name: Checkout
        uses: actions/checkout@v3

      - name: Set up Linux Env
        if: ${{ runner.os == 'Linux' }}
        run: |
          bash -x .github/workflows/scripts/env.sh

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install CUDA ${{ matrix.cuda-version }}
        run: |
          bash -x .github/workflows/scripts/cuda-install.sh ${{ matrix.cuda-version }} ${{ matrix.os }}

      - name: Install PyTorch ${{ matrix.pytorch-version }} with CUDA ${{ matrix.cuda-version }}
        run: |
          bash -x .github/workflows/scripts/pytorch-install.sh ${{ matrix.python-version }} ${{ matrix.pytorch-version }} ${{ matrix.cuda-version }}

      - name: Build wheel
        shell: bash
        run: |
          bash -x .github/workflows/scripts/build.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }}
          wheel_name=$(ls dist/*whl | xargs -n 1 basename)
          asset_name=${wheel_name//"linux"/"manylinux1"}
          echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
          echo "asset_name=${asset_name}" >> $GITHUB_ENV

      - name: Upload Release Asset
        uses: actions/upload-release-asset@v1
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        with:
          upload_url: ${{ needs.release.outputs.upload_url }}
          asset_path: ./dist/${{ env.wheel_name }}
          asset_name: ${{ env.asset_name }}
          asset_content_type: application/*

      # (Danielkinz): This last step will publish the .whl to pypi. Warning: untested
      # - name: Publish package
      #   uses: pypa/gh-action-pypi-publish@release/v1.8
      #   with:
      #     repository-url: https://test.pypi.org/legacy/
      #     password: ${{ secrets.PYPI_API_TOKEN }}
      #     skip-existing: true
```
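Since the workflow triggers on tags matching `v*`, a release build would be kicked off by pushing such a tag. The version below is only a placeholder example, not one defined by this commit:

```bash
# Pushing a version tag matching "v*" triggers the Create Release workflow.
git tag v0.3.0        # placeholder version number
git push origin v0.3.0
```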
.github/workflows/ruff.yml
ADDED
```yaml
name: ruff

on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  ruff:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.10"]
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install ruff==0.1.5
      - name: Analysing the code with ruff
        run: |
          ruff vllm tests
```
.github/workflows/scripts/build.sh
ADDED
```bash
#!/bin/bash

python_executable=python$1
cuda_home=/usr/local/cuda-$2

# Update paths
PATH=${cuda_home}/bin:$PATH
LD_LIBRARY_PATH=${cuda_home}/lib64:$LD_LIBRARY_PATH

# Install requirements
$python_executable -m pip install wheel packaging
$python_executable -m pip install -r requirements.txt

# Limit the number of parallel jobs to avoid OOM
export MAX_JOBS=1
# Make sure punica is built for the release (for LoRA)
export VLLM_INSTALL_PUNICA_KERNELS=1

# Build
$python_executable setup.py bdist_wheel --dist-dir=dist
```
.github/workflows/scripts/create_release.js
ADDED
```javascript
// Uses Github's API to create the release and wait for result.
// We use a JS script since github CLI doesn't provide a way to wait for the release's creation and returns immediately.

module.exports = async (github, context, core) => {
    try {
        const response = await github.rest.repos.createRelease({
            draft: false,
            generate_release_notes: true,
            name: process.env.RELEASE_TAG,
            owner: context.repo.owner,
            prerelease: false,
            repo: context.repo.repo,
            tag_name: process.env.RELEASE_TAG,
        });

        core.setOutput('upload_url', response.data.upload_url);
    } catch (error) {
        core.setFailed(error.message);
    }
}
```
.github/workflows/scripts/cuda-install.sh
ADDED
```bash
#!/bin/bash

# Replace '.' with '-' ex: 11.8 -> 11-8
cuda_version=$(echo $1 | tr "." "-")
# Removes '-' and '.' ex: ubuntu-20.04 -> ubuntu2004
OS=$(echo $2 | tr -d ".\-")

# Installs CUDA
wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb
rm cuda-keyring_1.1-1_all.deb
sudo apt -qq update
sudo apt -y install cuda-${cuda_version} cuda-nvcc-${cuda_version} cuda-libraries-dev-${cuda_version}
sudo apt clean

# Test nvcc
PATH=/usr/local/cuda-$1/bin:${PATH}
nvcc --version

# Log gcc, g++, c++ versions
gcc --version
g++ --version
c++ --version
```
.github/workflows/scripts/env.sh
ADDED
```bash
#!/bin/bash

# This file installs common linux environment tools

export LANG C.UTF-8

# python_version=$1

sudo apt-get update && \
sudo apt-get install -y --no-install-recommends \
    software-properties-common \

sudo apt-get install -y --no-install-recommends \
    build-essential \
    apt-utils \
    ca-certificates \
    wget \
    git \
    vim \
    libssl-dev \
    curl \
    unzip \
    unrar \
    cmake \
    net-tools \
    sudo \
    autotools-dev \
    rsync \
    jq \
    openssh-server \
    tmux \
    screen \
    htop \
    pdsh \
    openssh-client \
    lshw \
    dmidecode \
    util-linux \
    automake \
    autoconf \
    libtool \
    net-tools \
    pciutils \
    libpci-dev \
    libaio-dev \
    libcap2 \
    libtinfo5 \
    fakeroot \
    devscripts \
    debhelper \
    nfs-common

# Remove github bloat files to free up disk space
sudo rm -rf "/usr/local/share/boost"
sudo rm -rf "$AGENT_TOOLSDIRECTORY"
sudo rm -rf "/usr/share/dotnet"
```
.github/workflows/scripts/pytorch-install.sh
ADDED
```bash
#!/bin/bash

python_executable=python$1
pytorch_version=$2
cuda_version=$3

# Install torch
$python_executable -m pip install numpy pyyaml scipy ipython mkl mkl-include ninja cython typing pandas typing-extensions dataclasses setuptools && conda clean -ya
$python_executable -m pip install torch==${pytorch_version}+cu${cuda_version//./} --extra-index-url https://download.pytorch.org/whl/cu${cuda_version//./}

# Print version information
$python_executable --version
$python_executable -c "import torch; print('PyTorch:', torch.__version__)"
$python_executable -c "import torch; print('CUDA:', torch.version.cuda)"
$python_executable -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)"
```
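For reference, the `Build Wheel` job in `publish.yml` chains these helper scripts in order. A sketch of reproducing that sequence outside CI for one matrix combination (Python 3.10, PyTorch 2.1.2, CUDA 12.1 on `ubuntu-20.04`; any other combination from the matrix would work the same way):

```bash
# Illustrative local run of the wheel-build steps, mirroring publish.yml.
bash -x .github/workflows/scripts/env.sh
bash -x .github/workflows/scripts/cuda-install.sh 12.1 ubuntu-20.04
bash -x .github/workflows/scripts/pytorch-install.sh 3.10 2.1.2 12.1
bash -x .github/workflows/scripts/build.sh 3.10 12.1   # wheel lands in dist/
```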
.github/workflows/yapf.yml
ADDED
```yaml
name: yapf

on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch
  push:
    branches:
      - main
  pull_request:
    branches:
      - main
jobs:
  yapf:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.10"]
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install yapf==0.32.0
          pip install toml==0.10.2
      - name: Running yapf
        run: |
          yapf --diff --recursive .
```
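The two lint workflows above run the same commands one can run locally before pushing (from the repository root, using the pinned versions from the workflows):

```bash
pip install ruff==0.1.5 yapf==0.32.0 toml==0.10.2
ruff vllm tests            # lint check from ruff.yml
yapf --diff --recursive .  # formatting check from yapf.yml
```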
.gitignore
ADDED
```
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

# VSCode
.vscode/

# DS Store
.DS_Store

# Results
*.csv

# Python pickle files
*.pkl

# Sphinx documentation
_build/

# vim swap files
*.swo
*.swp

# hip files generated by PyTorch
*.hip
*_hip*

# Benchmark dataset
*.json
```
.readthedocs.yaml
ADDED
```yaml
# Read the Docs configuration file
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details

version: 2

build:
  os: ubuntu-22.04
  tools:
    python: "3.8"

sphinx:
  configuration: docs/source/conf.py

# If using Sphinx, optionally build your docs in additional formats such as PDF
formats:
  - pdf

# Optionally declare the Python requirements required to build your docs
python:
  install:
    - requirements: docs/requirements-docs.txt
```
CONTRIBUTING.md
ADDED
````markdown
# Contributing to vLLM

Thank you for your interest in contributing to vLLM!
Our community is open to everyone and welcomes all kinds of contributions, no matter how small or large.
There are several ways you can contribute to the project:

- Identify and report any issues or bugs.
- Request or add a new model.
- Suggest or implement new features.

However, remember that contributions aren't just about code.
We believe in the power of community support; thus, answering queries, assisting others, and enhancing the documentation are highly regarded and beneficial contributions.

Finally, one of the most impactful ways to support us is by raising awareness about vLLM.
Talk about it in your blog posts, highlighting how it's driving your incredible projects.
Express your support on Twitter if vLLM aids you, or simply offer your appreciation by starring our repository.


## Setup for development

### Build from source

```bash
pip install -r requirements.txt
pip install -e .  # This may take several minutes.
```

### Testing

```bash
pip install -r requirements-dev.txt

# Static type checking
mypy
# Unit tests
pytest tests/
```
**Note:** Currently, the repository does not pass the mypy tests.


## Contributing Guidelines

### Issue Reporting

If you encounter a bug or have a feature request, please check our issues page first to see if someone else has already reported it.
If not, please file a new issue, providing as much relevant information as possible.

### Coding Style Guide

In general, we adhere to [Google Python style guide](https://google.github.io/styleguide/pyguide.html) and [Google C++ style guide](https://google.github.io/styleguide/cppguide.html).

We include a formatting script [`format.sh`](./format.sh) to format the code.

### Pull Requests

When submitting a pull request:

1. Make sure your code has been rebased on top of the latest commit on the main branch.
2. Ensure code is properly formatted by running [`format.sh`](./format.sh).
3. Include a detailed description of the changes in the pull request.
Explain why you made the changes you did.
If your pull request fixes an open issue, please include a reference to it in the description.

### Code Reviews

All submissions, including submissions by project members, require a code review.
To make the review process as smooth as possible, please:

1. Keep your changes as concise as possible.
If your pull request involves multiple unrelated changes, consider splitting it into separate pull requests.
2. Respond to all comments within a reasonable time frame.
If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

### Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM.
Your contributions make vLLM a great tool for everyone!
````
Dockerfile
CHANGED
```diff
@@ -94,4 +94,4 @@ COPY --from=build /workspace/vllm/*.so /workspace/vllm/
 COPY vllm vllm

 ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
-#################### OPENAI API SERVER ####################
+#################### OPENAI API SERVER ####################
```
Dockerfile.rocm
ADDED
```dockerfile
# default base image
ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"

FROM $BASE_IMAGE

ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"

RUN echo "Base image is $BASE_IMAGE"

# BASE_IMAGE for ROCm_5.7: "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1"
# BASE_IMAGE for ROCm_6.0: "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"

# this does not always work for all rocm versions
RUN LLVM_GFX_ARCH=$(/opt/rocm/llvm/bin/amdgpu-offload-arch) && \
    echo "LLVM_GFX_ARCH is $LLVM_GFX_ARCH"

ARG FA_GFX_ARCHS="gfx90a;gfx942"
RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"

ARG FA_BRANCH="3d2b6f5"
RUN echo "FA_BRANCH is $FA_BRANCH"

# Install some basic utilities
RUN apt-get update && apt-get install python3 python3-pip -y

# Install some basic utilities
RUN apt-get update && apt-get install -y \
    curl \
    ca-certificates \
    sudo \
    git \
    bzip2 \
    libx11-6 \
    build-essential \
    wget \
    unzip \
    nvidia-cuda-toolkit \
    tmux \
 && rm -rf /var/lib/apt/lists/*

### Mount Point ###
# When launching the container, mount the code directory to /app
ARG APP_MOUNT=/app
VOLUME [ ${APP_MOUNT} ]
WORKDIR ${APP_MOUNT}

RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas

ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin:
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:

# Install ROCm flash-attention
RUN mkdir libs \
    && cd libs \
    && git clone https://github.com/ROCmSoftwarePlatform/flash-attention.git \
    && cd flash-attention \
    && git checkout ${FA_BRANCH} \
    && git submodule update --init \
    && export GPU_ARCHS=${FA_GFX_ARCHS} \
    && if [ "$BASE_IMAGE" = "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" ]; then \
       patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch; fi \
    && python3 setup.py install \
    && cd ..

COPY ./ /app/vllm

RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install xformers==0.0.23 --no-deps

# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt.
# Manually removed it so that later steps of numpy upgrade can continue
RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \
    rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi

RUN cd /app \
    && cd vllm \
    && pip install -U -r requirements-rocm.txt \
    && bash patch_xformers.rocm.sh \
    && python3 setup.py install \
    && cd ..

RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install --no-cache-dir ray[all]

CMD ["/bin/bash"]
```
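As the comments in the Dockerfile note, `BASE_IMAGE` selects between the ROCm 6.0 (default) and ROCm 5.7 base images. A sketch of building the image; the `vllm-rocm` tag is an arbitrary example, not defined by this commit:

```bash
# Default (ROCm 6.0) build.
docker build -f Dockerfile.rocm -t vllm-rocm .

# Same build against the ROCm 5.7 base image listed in the Dockerfile comments.
docker build -f Dockerfile.rocm \
  --build-arg BASE_IMAGE="rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" \
  -t vllm-rocm .
```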
LICENSE
ADDED
The full text of the Apache License, Version 2.0 (January 2004, http://www.apache.org/licenses/), 201 lines, including the standard appendix with the `[yyyy] [name of copyright owner]` placeholders.
MANIFEST.in
ADDED
```
include LICENSE
include requirements.txt

recursive-include csrc *
```
README.md
CHANGED
The previous 10-line stub (separator lines and blanks) is replaced; the file now reads:

````markdown
<p align="center">
  <picture>
    <source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/logos/vllm-logo-text-dark.png">
    <img alt="vLLM" src="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/logos/vllm-logo-text-light.png" width=55%>
  </picture>
</p>

<h3 align="center">
Easy, fast, and cheap LLM serving for everyone
</h3>

<p align="center">
| <a href="https://docs.vllm.ai"><b>Documentation</b></a> | <a href="https://vllm.ai"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://discord.gg/jz7wjKhh6g"><b>Discord</b></a> |

</p>

---

**The Second vLLM Bay Area Meetup (Jan 31st 5pm-7:30pm PT)**

We are thrilled to announce our second vLLM Meetup!
The vLLM team will share recent updates and roadmap.
We will also have vLLM collaborators from IBM coming up to the stage to discuss their insights on LLM optimizations.
Please register [here](https://lu.ma/ygxbpzhl) and join us!

---

*Latest News* 🔥
- [2024/01] Added ROCm 6.0 support to vLLM.
- [2023/12] Added ROCm 5.7 support to vLLM.
- [2023/10] We hosted [the first vLLM meetup](https://lu.ma/first-vllm-meetup) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?usp=sharing).
- [2023/09] We created our [Discord server](https://discord.gg/jz7wjKhh6g)! Join us to discuss vLLM and LLM serving! We will also post the latest announcements and updates there.
- [2023/09] We released our [PagedAttention paper](https://arxiv.org/abs/2309.06180) on arXiv!
- [2023/08] We would like to express our sincere gratitude to [Andreessen Horowitz](https://a16z.com/2023/08/30/supporting-the-open-source-ai-community/) (a16z) for providing a generous grant to support the open-source development and research of vLLM.
- [2023/07] Added support for LLaMA-2! You can run and serve 7B/13B/70B LLaMA-2s on vLLM with a single command!
- [2023/06] Serving vLLM On any Cloud with SkyPilot. Check out a 1-click [example](https://github.com/skypilot-org/skypilot/blob/master/llm/vllm) to start the vLLM demo, and the [blog post](https://blog.skypilot.co/serving-llm-24x-faster-on-the-cloud-with-vllm-and-skypilot/) for the story behind vLLM development on the clouds.
- [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai).

---
## About
vLLM is a fast and easy-to-use library for LLM inference and serving.

vLLM is fast with:

- State-of-the-art serving throughput
- Efficient management of attention key and value memory with **PagedAttention**
- Continuous batching of incoming requests
- Fast model execution with CUDA/HIP graph
- Quantization: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [SqueezeLLM](https://arxiv.org/abs/2306.07629)
- Optimized CUDA kernels

vLLM is flexible and easy to use with:

- Seamless integration with popular Hugging Face models
- High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more
- Tensor parallelism support for distributed inference
- Streaming outputs
- OpenAI-compatible API server
- Support for NVIDIA GPUs and AMD GPUs

vLLM seamlessly supports many Hugging Face models, including the following architectures:

- Aquila & Aquila2 (`BAAI/AquilaChat2-7B`, `BAAI/AquilaChat2-34B`, `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.)
- Baichuan & Baichuan2 (`baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.)
- BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
- ChatGLM (`THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, etc.)
- DeciLM (`Deci/DeciLM-7B`, `Deci/DeciLM-7B-instruct`, etc.)
- Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.)
- GPT-2 (`gpt2`, `gpt2-xl`, etc.)
- GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
- GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.)
- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
- InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.)
- LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.)
- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
- Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.)
- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
- Qwen2 (`Qwen/Qwen2-7B-beta`, `Qwen/Qwen-7B-Chat-beta`, etc.)
- StableLM (`stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc.)
- Yi (`01-ai/Yi-6B`, `01-ai/Yi-34B`, etc.)

Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):

```bash
pip install vllm
```

## Getting Started

Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to get started.
- [Installation](https://vllm.readthedocs.io/en/latest/getting_started/installation.html)
- [Quickstart](https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html)
- [Supported Models](https://vllm.readthedocs.io/en/latest/models/supported_models.html)

## Contributing

We welcome and value any contributions and collaborations.
Please check out [CONTRIBUTING.md](./CONTRIBUTING.md) for how to get involved.

## Citation

If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs/2309.06180):
```bibtex
@inproceedings{kwon2023efficient,
  title={Efficient Memory Management for Large Language Model Serving with PagedAttention},
  author={Woosuk Kwon and Zhuohan Li and Siyuan Zhuang and Ying Sheng and Lianmin Zheng and Cody Hao Yu and Joseph E. Gonzalez and Hao Zhang and Ion Stoica},
  booktitle={Proceedings of the ACM SIGOPS 29th Symposium on Operating Systems Principles},
  year={2023}
}
```
````
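To illustrate the library usage the README points to, here is a minimal offline-inference sketch based on the `LLM`/`SamplingParams` API that appears in `benchmarks/benchmark_latency.py` later in this commit; the prompt, sampling values, and model name (`facebook/opt-125m`, the benchmark's default) are examples, not part of the diff:

```python
# Minimal offline-inference sketch using the API exercised by the benchmarks.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # example model
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=64)

outputs = llm.generate(["Hello, my name is"], sampling_params)
for output in outputs:
    print(output.outputs[0].text)
```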
benchmarks/README.md
ADDED
````markdown
# Benchmarking vLLM

## Downloading the ShareGPT dataset

You can download the dataset by running:
```bash
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
```
````
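Once downloaded, the dataset can be used the same way `.buildkite/run-benchmarks.sh` above uses it: start the OpenAI-compatible server, then run the serving benchmark against it (the flags below are copied from that script):

```bash
python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-chat-hf &
python3 benchmarks/benchmark_serving.py \
    --dataset ./ShareGPT_V3_unfiltered_cleaned_split.json \
    --model meta-llama/Llama-2-7b-chat-hf \
    --tokenizer meta-llama/Llama-2-7b-chat-hf \
    --endpoint /v1/completions \
    --num-prompts 20
```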
benchmarks/benchmark_latency.py
ADDED
@@ -0,0 +1,139 @@
"""Benchmark the latency of processing a single batch of requests."""
import argparse
import time
from pathlib import Path
from typing import Optional

import numpy as np
import torch
from tqdm import tqdm

from vllm import LLM, SamplingParams


def main(args: argparse.Namespace):
    print(args)

    # NOTE(woosuk): If the request cannot be processed in a single batch,
    # the engine will automatically process the request in multiple batches.
    llm = LLM(
        model=args.model,
        tokenizer=args.tokenizer,
        quantization=args.quantization,
        tensor_parallel_size=args.tensor_parallel_size,
        trust_remote_code=args.trust_remote_code,
        dtype=args.dtype,
        enforce_eager=args.enforce_eager,
        kv_cache_dtype=args.kv_cache_dtype,
    )

    sampling_params = SamplingParams(
        n=args.n,
        temperature=0.0 if args.use_beam_search else 1.0,
        top_p=1.0,
        use_beam_search=args.use_beam_search,
        ignore_eos=True,
        max_tokens=args.output_len,
    )
    print(sampling_params)
    dummy_prompt_token_ids = [[0] * args.input_len] * args.batch_size

    def run_to_completion(profile_dir: Optional[str] = None):
        if profile_dir:
            with torch.profiler.profile(
                    activities=[
                        torch.profiler.ProfilerActivity.CPU,
                        torch.profiler.ProfilerActivity.CUDA,
                    ],
                    on_trace_ready=torch.profiler.tensorboard_trace_handler(
                        str(profile_dir))) as p:
                llm.generate(prompt_token_ids=dummy_prompt_token_ids,
                             sampling_params=sampling_params,
                             use_tqdm=False)
            print(p.key_averages())
        else:
            start_time = time.perf_counter()
            llm.generate(prompt_token_ids=dummy_prompt_token_ids,
                         sampling_params=sampling_params,
                         use_tqdm=False)
            end_time = time.perf_counter()
            latency = end_time - start_time
            return latency

    print("Warming up...")
    run_to_completion(profile_dir=None)

    if args.profile:
        profile_dir = args.profile_result_dir
        if not profile_dir:
            profile_dir = Path(
                "."
            ) / "vllm_benchmark_result" / f"latency_result_{time.time()}"
        print(f"Profiling (results will be saved to '{profile_dir}')...")
        # Use the resolved directory (with the default applied) rather than the
        # raw --profile-result-dir argument, which may still be None here.
        run_to_completion(profile_dir=profile_dir)
        return

    # Benchmark.
    latencies = []
    for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
        latencies.append(run_to_completion(profile_dir=None))
    print(f'Avg latency: {np.mean(latencies)} seconds')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Benchmark the latency of processing a single batch of '
        'requests till completion.')
    parser.add_argument('--model', type=str, default='facebook/opt-125m')
    parser.add_argument('--tokenizer', type=str, default=None)
    parser.add_argument('--quantization',
                        '-q',
                        choices=['awq', 'gptq', 'squeezellm', None],
                        default=None)
    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
    parser.add_argument('--input-len', type=int, default=32)
    parser.add_argument('--output-len', type=int, default=128)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--n',
                        type=int,
                        default=1,
                        help='Number of generated sequences per prompt.')
    parser.add_argument('--use-beam-search', action='store_true')
    parser.add_argument('--num-iters',
                        type=int,
                        default=3,
                        help='Number of iterations to run.')
    parser.add_argument('--trust-remote-code',
                        action='store_true',
                        help='trust remote code from huggingface')
    parser.add_argument(
        '--dtype',
        type=str,
        default='auto',
        choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
        help='data type for model weights and activations. '
        'The "auto" option will use FP16 precision '
        'for FP32 and FP16 models, and BF16 precision '
        'for BF16 models.')
    parser.add_argument('--enforce-eager',
                        action='store_true',
                        help='enforce eager mode and disable CUDA graph')
    parser.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=['auto', 'fp8_e5m2'],
        default='auto',
        help=
        'Data type for kv cache storage. If "auto", will use model data type.')
    parser.add_argument(
        '--profile',
        action='store_true',
        help='profile the generation process of a single batch')
    parser.add_argument(
        '--profile-result-dir',
        type=str,
        default=None,
        help=('path to save the pytorch profiler output. Can be visualized '
              'with ui.perfetto.dev or Tensorboard.'))
    args = parser.parse_args()
    main(args)
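The script above reports only the mean latency. Where tail latency matters, a small post-processing sketch (reusing the collected `latencies` list and NumPy, nothing vLLM-specific; the values below are placeholders) can report percentiles as well:

```python
import numpy as np

# `latencies` is the list of per-iteration wall-clock times collected above;
# these numbers are illustrative placeholders.
latencies = [1.92, 2.01, 1.97, 2.35, 1.95]

print(f"Avg latency: {np.mean(latencies):.3f} s")
for p in (50, 90, 99):
    print(f"p{p} latency: {np.percentile(latencies, p):.3f} s")
```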
benchmarks/benchmark_serving.py
ADDED
@@ -0,0 +1,249 @@
1 |
+
"""Benchmark online serving throughput.
|
2 |
+
|
3 |
+
On the server side, run one of the following commands:
|
4 |
+
(vLLM backend)
|
5 |
+
python -m vllm.entrypoints.api_server \
|
6 |
+
--model <your_model> --swap-space 16 \
|
7 |
+
--disable-log-requests
|
8 |
+
|
9 |
+
(TGI backend)
|
10 |
+
./launch_hf_server.sh <your_model>
|
11 |
+
|
12 |
+
On the client side, run:
|
13 |
+
python benchmarks/benchmark_serving.py \
|
14 |
+
--backend <backend> \
|
15 |
+
--tokenizer <your_model> --dataset <target_dataset> \
|
16 |
+
--request-rate <request_rate>
|
17 |
+
"""
|
18 |
+
import argparse
|
19 |
+
import asyncio
|
20 |
+
import json
|
21 |
+
import random
|
22 |
+
import time
|
23 |
+
from typing import AsyncGenerator, List, Tuple
|
24 |
+
|
25 |
+
import aiohttp
|
26 |
+
import numpy as np
|
27 |
+
from tqdm.asyncio import tqdm
|
28 |
+
from transformers import PreTrainedTokenizerBase
|
29 |
+
from vllm.transformers_utils.tokenizer import get_tokenizer
|
30 |
+
|
31 |
+
# (prompt len, output len, latency)
|
32 |
+
REQUEST_LATENCY: List[Tuple[int, int, float]] = []
|
33 |
+
|
34 |
+
|
35 |
+
def sample_requests(
|
36 |
+
dataset_path: str,
|
37 |
+
num_requests: int,
|
38 |
+
tokenizer: PreTrainedTokenizerBase,
|
39 |
+
) -> List[Tuple[str, int, int]]:
|
40 |
+
# Load the dataset.
|
41 |
+
with open(dataset_path) as f:
|
42 |
+
dataset = json.load(f)
|
43 |
+
# Filter out the conversations with less than 2 turns.
|
44 |
+
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
|
45 |
+
# Only keep the first two turns of each conversation.
|
46 |
+
dataset = [(data["conversations"][0]["value"],
|
47 |
+
data["conversations"][1]["value"]) for data in dataset]
|
48 |
+
|
49 |
+
# Tokenize the prompts and completions.
|
50 |
+
prompts = [prompt for prompt, _ in dataset]
|
51 |
+
prompt_token_ids = tokenizer(prompts).input_ids
|
52 |
+
completions = [completion for _, completion in dataset]
|
53 |
+
completion_token_ids = tokenizer(completions).input_ids
|
54 |
+
tokenized_dataset = []
|
55 |
+
for i in range(len(dataset)):
|
56 |
+
output_len = len(completion_token_ids[i])
|
57 |
+
tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
|
58 |
+
|
59 |
+
# Filter out too long sequences.
|
60 |
+
filtered_dataset: List[Tuple[str, int, int]] = []
|
61 |
+
for prompt, prompt_token_ids, output_len in tokenized_dataset:
|
62 |
+
prompt_len = len(prompt_token_ids)
|
63 |
+
if prompt_len < 4 or output_len < 4:
|
64 |
+
# Prune too short sequences.
|
65 |
+
# This is because TGI causes errors when the input or output length
|
66 |
+
# is too short.
|
67 |
+
continue
|
68 |
+
if prompt_len > 1024 or prompt_len + output_len > 2048:
|
69 |
+
# Prune too long sequences.
|
70 |
+
continue
|
71 |
+
filtered_dataset.append((prompt, prompt_len, output_len))
|
72 |
+
|
73 |
+
# Sample the requests.
|
74 |
+
sampled_requests = random.sample(filtered_dataset, num_requests)
|
75 |
+
return sampled_requests
|
76 |
+
|
77 |
+
|
78 |
+
async def get_request(
|
79 |
+
input_requests: List[Tuple[str, int, int]],
|
80 |
+
request_rate: float,
|
81 |
+
) -> AsyncGenerator[Tuple[str, int, int], None]:
|
82 |
+
input_requests = iter(input_requests)
|
83 |
+
for request in input_requests:
|
84 |
+
yield request
|
85 |
+
|
86 |
+
if request_rate == float("inf"):
|
87 |
+
# If the request rate is infinity, then we don't need to wait.
|
88 |
+
continue
|
89 |
+
# Sample the request interval from the exponential distribution.
|
90 |
+
interval = np.random.exponential(1.0 / request_rate)
|
91 |
+
# The next request will be sent after the interval.
|
92 |
+
await asyncio.sleep(interval)
|
93 |
+
|
94 |
+
|
95 |
+
async def send_request(backend: str, model: str, api_url: str, prompt: str,
|
96 |
+
prompt_len: int, output_len: int, best_of: int,
|
97 |
+
use_beam_search: bool, pbar: tqdm) -> None:
|
98 |
+
request_start_time = time.perf_counter()
|
99 |
+
|
100 |
+
headers = {"User-Agent": "Benchmark Client"}
|
101 |
+
if backend == "vllm":
|
102 |
+
pload = {
|
103 |
+
"prompt": prompt,
|
104 |
+
"n": 1,
|
105 |
+
"best_of": best_of,
|
106 |
+
"use_beam_search": use_beam_search,
|
107 |
+
"temperature": 0.0 if use_beam_search else 1.0,
|
108 |
+
"top_p": 1.0,
|
109 |
+
"max_tokens": output_len,
|
110 |
+
"ignore_eos": True,
|
111 |
+
"stream": False,
|
112 |
+
}
|
113 |
+
if model is not None:
|
114 |
+
pload["model"] = model
|
115 |
+
elif backend == "tgi":
|
116 |
+
assert not use_beam_search
|
117 |
+
params = {
|
118 |
+
"best_of": best_of,
|
119 |
+
"max_new_tokens": output_len,
|
120 |
+
"do_sample": True,
|
121 |
+
}
|
122 |
+
pload = {
|
123 |
+
"inputs": prompt,
|
124 |
+
"parameters": params,
|
125 |
+
}
|
126 |
+
else:
|
127 |
+
raise ValueError(f"Unknown backend: {backend}")
|
128 |
+
|
129 |
+
timeout = aiohttp.ClientTimeout(total=3 * 3600)
|
130 |
+
async with aiohttp.ClientSession(timeout=timeout) as session:
|
131 |
+
while True:
|
132 |
+
async with session.post(api_url, headers=headers,
|
133 |
+
json=pload) as response:
|
134 |
+
chunks = []
|
135 |
+
async for chunk, _ in response.content.iter_chunks():
|
136 |
+
chunks.append(chunk)
|
137 |
+
output = b"".join(chunks).decode("utf-8")
|
138 |
+
output = json.loads(output)
|
139 |
+
|
140 |
+
# Re-send the request if it failed.
|
141 |
+
if "error" not in output:
|
142 |
+
break
|
143 |
+
|
144 |
+
request_end_time = time.perf_counter()
|
145 |
+
request_latency = request_end_time - request_start_time
|
146 |
+
REQUEST_LATENCY.append((prompt_len, output_len, request_latency))
|
147 |
+
pbar.update(1)
|
148 |
+
|
149 |
+
|
150 |
+
async def benchmark(
|
151 |
+
backend: str,
|
152 |
+
model: str,
|
153 |
+
api_url: str,
|
154 |
+
input_requests: List[Tuple[str, int, int]],
|
155 |
+
best_of: int,
|
156 |
+
use_beam_search: bool,
|
157 |
+
request_rate: float,
|
158 |
+
) -> None:
|
159 |
+
tasks: List[asyncio.Task] = []
|
160 |
+
pbar = tqdm(total=len(input_requests))
|
161 |
+
async for request in get_request(input_requests, request_rate):
|
162 |
+
prompt, prompt_len, output_len = request
|
163 |
+
task = asyncio.create_task(
|
164 |
+
send_request(backend, model, api_url, prompt, prompt_len,
|
165 |
+
output_len, best_of, use_beam_search, pbar))
|
166 |
+
tasks.append(task)
|
167 |
+
await asyncio.gather(*tasks)
|
168 |
+
pbar.close()
|
169 |
+
|
170 |
+
|
171 |
+
def main(args: argparse.Namespace):
|
172 |
+
print(args)
|
173 |
+
random.seed(args.seed)
|
174 |
+
np.random.seed(args.seed)
|
175 |
+
|
176 |
+
api_url = f"{args.protocol}://{args.host}:{args.port}{args.endpoint}"
|
177 |
+
tokenizer = get_tokenizer(args.tokenizer,
|
178 |
+
trust_remote_code=args.trust_remote_code)
|
179 |
+
input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
|
180 |
+
|
181 |
+
benchmark_start_time = time.perf_counter()
|
182 |
+
asyncio.run(
|
183 |
+
benchmark(args.backend, args.model, api_url, input_requests,
|
184 |
+
args.best_of, args.use_beam_search, args.request_rate))
|
185 |
+
benchmark_end_time = time.perf_counter()
|
186 |
+
benchmark_time = benchmark_end_time - benchmark_start_time
|
187 |
+
print(f"Total time: {benchmark_time:.2f} s")
|
188 |
+
print(f"Throughput: {args.num_prompts / benchmark_time:.2f} requests/s")
|
189 |
+
|
190 |
+
# Compute the latency statistics.
|
191 |
+
avg_latency = np.mean([latency for _, _, latency in REQUEST_LATENCY])
|
192 |
+
print(f"Average latency: {avg_latency:.2f} s")
|
193 |
+
avg_per_token_latency = np.mean([
|
194 |
+
latency / (prompt_len + output_len)
|
195 |
+
for prompt_len, output_len, latency in REQUEST_LATENCY
|
196 |
+
])
|
197 |
+
print(f"Average latency per token: {avg_per_token_latency:.2f} s")
|
198 |
+
avg_per_output_token_latency = np.mean(
|
199 |
+
[latency / output_len for _, output_len, latency in REQUEST_LATENCY])
|
200 |
+
print("Average latency per output token: "
|
201 |
+
f"{avg_per_output_token_latency:.2f} s")
|
202 |
+
|
203 |
+
|
204 |
+
if __name__ == "__main__":
|
205 |
+
parser = argparse.ArgumentParser(
|
206 |
+
description="Benchmark the online serving throughput.")
|
207 |
+
parser.add_argument("--backend",
|
208 |
+
type=str,
|
209 |
+
default="vllm",
|
210 |
+
choices=["vllm", "tgi"])
|
211 |
+
parser.add_argument("--protocol",
|
212 |
+
type=str,
|
213 |
+
default="http",
|
214 |
+
choices=["http", "https"])
|
215 |
+
parser.add_argument("--host", type=str, default="localhost")
|
216 |
+
parser.add_argument("--port", type=int, default=8000)
|
217 |
+
parser.add_argument("--endpoint", type=str, default="/generate")
|
218 |
+
parser.add_argument("--model", type=str, default=None)
|
219 |
+
parser.add_argument("--dataset",
|
220 |
+
type=str,
|
221 |
+
required=True,
|
222 |
+
help="Path to the dataset.")
|
223 |
+
parser.add_argument("--tokenizer",
|
224 |
+
type=str,
|
225 |
+
required=True,
|
226 |
+
help="Name or path of the tokenizer.")
|
227 |
+
parser.add_argument("--best-of",
|
228 |
+
type=int,
|
229 |
+
default=1,
|
230 |
+
help="Generates `best_of` sequences per prompt and "
|
231 |
+
"returns the best one.")
|
232 |
+
parser.add_argument("--use-beam-search", action="store_true")
|
233 |
+
parser.add_argument("--num-prompts",
|
234 |
+
type=int,
|
235 |
+
default=1000,
|
236 |
+
help="Number of prompts to process.")
|
237 |
+
parser.add_argument("--request-rate",
|
238 |
+
type=float,
|
239 |
+
default=float("inf"),
|
240 |
+
help="Number of requests per second. If this is inf, "
|
241 |
+
"then all the requests are sent at time 0. "
|
242 |
+
"Otherwise, we use Poisson process to synthesize "
|
243 |
+
"the request arrival times.")
|
244 |
+
parser.add_argument("--seed", type=int, default=0)
|
245 |
+
parser.add_argument('--trust-remote-code',
|
246 |
+
action='store_true',
|
247 |
+
help='trust remote code from huggingface')
|
248 |
+
args = parser.parse_args()
|
249 |
+
main(args)
|
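A note on `--request-rate`: `get_request` spaces requests by drawing inter-arrival gaps from an exponential distribution, which yields a Poisson arrival process; an infinite rate sends everything at time 0. The standalone sketch below illustrates the same idea with made-up requests and rate, independent of the script's internals:

```python
import asyncio

import numpy as np


async def paced_requests(requests, request_rate):
    """Yield requests with exponential inter-arrival gaps (a Poisson process)."""
    for request in requests:
        yield request
        if request_rate == float("inf"):
            continue  # Send everything at once.
        await asyncio.sleep(np.random.exponential(1.0 / request_rate))


async def demo():
    async for r in paced_requests(["req-0", "req-1", "req-2"], request_rate=2.0):
        print("sending", r)


asyncio.run(demo())
```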
benchmarks/benchmark_throughput.py
ADDED
@@ -0,0 +1,328 @@
1 |
+
"""Benchmark offline inference throughput."""
|
2 |
+
import argparse
|
3 |
+
import json
|
4 |
+
import random
|
5 |
+
import time
|
6 |
+
from typing import List, Optional, Tuple
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from transformers import (AutoModelForCausalLM, AutoTokenizer,
|
10 |
+
PreTrainedTokenizerBase)
|
11 |
+
from tqdm import tqdm
|
12 |
+
|
13 |
+
|
14 |
+
def sample_requests(
|
15 |
+
dataset_path: str,
|
16 |
+
num_requests: int,
|
17 |
+
tokenizer: PreTrainedTokenizerBase,
|
18 |
+
fixed_output_len: Optional[int],
|
19 |
+
) -> List[Tuple[str, int, int]]:
|
20 |
+
if fixed_output_len is not None and fixed_output_len < 4:
|
21 |
+
raise ValueError("output_len too small")
|
22 |
+
|
23 |
+
# Load the dataset.
|
24 |
+
with open(dataset_path) as f:
|
25 |
+
dataset = json.load(f)
|
26 |
+
# Filter out the conversations with less than 2 turns.
|
27 |
+
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
|
28 |
+
# Only keep the first two turns of each conversation.
|
29 |
+
dataset = [(data["conversations"][0]["value"],
|
30 |
+
data["conversations"][1]["value"]) for data in dataset]
|
31 |
+
|
32 |
+
# Tokenize the prompts and completions.
|
33 |
+
prompts = [prompt for prompt, _ in dataset]
|
34 |
+
prompt_token_ids = tokenizer(prompts).input_ids
|
35 |
+
completions = [completion for _, completion in dataset]
|
36 |
+
completion_token_ids = tokenizer(completions).input_ids
|
37 |
+
tokenized_dataset = []
|
38 |
+
for i in range(len(dataset)):
|
39 |
+
output_len = len(completion_token_ids[i])
|
40 |
+
if fixed_output_len is not None:
|
41 |
+
output_len = fixed_output_len
|
42 |
+
tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
|
43 |
+
|
44 |
+
# Filter out too long sequences.
|
45 |
+
filtered_dataset: List[Tuple[str, int, int]] = []
|
46 |
+
for prompt, prompt_token_ids, output_len in tokenized_dataset:
|
47 |
+
prompt_len = len(prompt_token_ids)
|
48 |
+
if prompt_len < 4 or output_len < 4:
|
49 |
+
# Prune too short sequences.
|
50 |
+
continue
|
51 |
+
if prompt_len > 1024 or prompt_len + output_len > 2048:
|
52 |
+
# Prune too long sequences.
|
53 |
+
continue
|
54 |
+
filtered_dataset.append((prompt, prompt_len, output_len))
|
55 |
+
|
56 |
+
# Sample the requests.
|
57 |
+
sampled_requests = random.sample(filtered_dataset, num_requests)
|
58 |
+
return sampled_requests
|
59 |
+
|
60 |
+
|
61 |
+
def run_vllm(
|
62 |
+
requests: List[Tuple[str, int, int]],
|
63 |
+
model: str,
|
64 |
+
tokenizer: str,
|
65 |
+
quantization: Optional[str],
|
66 |
+
tensor_parallel_size: int,
|
67 |
+
seed: int,
|
68 |
+
n: int,
|
69 |
+
use_beam_search: bool,
|
70 |
+
trust_remote_code: bool,
|
71 |
+
dtype: str,
|
72 |
+
max_model_len: Optional[int],
|
73 |
+
enforce_eager: bool,
|
74 |
+
kv_cache_dtype: str,
|
75 |
+
) -> float:
|
76 |
+
from vllm import LLM, SamplingParams
|
77 |
+
llm = LLM(
|
78 |
+
model=model,
|
79 |
+
tokenizer=tokenizer,
|
80 |
+
quantization=quantization,
|
81 |
+
tensor_parallel_size=tensor_parallel_size,
|
82 |
+
seed=seed,
|
83 |
+
trust_remote_code=trust_remote_code,
|
84 |
+
dtype=dtype,
|
85 |
+
max_model_len=max_model_len,
|
86 |
+
enforce_eager=enforce_eager,
|
87 |
+
kv_cache_dtype=kv_cache_dtype,
|
88 |
+
)
|
89 |
+
|
90 |
+
# Add the requests to the engine.
|
91 |
+
for prompt, _, output_len in requests:
|
92 |
+
sampling_params = SamplingParams(
|
93 |
+
n=n,
|
94 |
+
temperature=0.0 if use_beam_search else 1.0,
|
95 |
+
top_p=1.0,
|
96 |
+
use_beam_search=use_beam_search,
|
97 |
+
ignore_eos=True,
|
98 |
+
max_tokens=output_len,
|
99 |
+
)
|
100 |
+
# FIXME(woosuk): Do not use internal method.
|
101 |
+
llm._add_request(
|
102 |
+
prompt=prompt,
|
103 |
+
prompt_token_ids=None,
|
104 |
+
sampling_params=sampling_params,
|
105 |
+
)
|
106 |
+
|
107 |
+
start = time.perf_counter()
|
108 |
+
# FIXME(woosuk): Do not use internal method.
|
109 |
+
llm._run_engine(use_tqdm=True)
|
110 |
+
end = time.perf_counter()
|
111 |
+
return end - start
|
112 |
+
|
113 |
+
|
114 |
+
def run_hf(
|
115 |
+
requests: List[Tuple[str, int, int]],
|
116 |
+
model: str,
|
117 |
+
tokenizer: PreTrainedTokenizerBase,
|
118 |
+
n: int,
|
119 |
+
use_beam_search: bool,
|
120 |
+
max_batch_size: int,
|
121 |
+
trust_remote_code: bool,
|
122 |
+
) -> float:
|
123 |
+
assert not use_beam_search
|
124 |
+
llm = AutoModelForCausalLM.from_pretrained(
|
125 |
+
model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
|
126 |
+
if llm.config.model_type == "llama":
|
127 |
+
# To enable padding in the HF backend.
|
128 |
+
tokenizer.pad_token = tokenizer.eos_token
|
129 |
+
llm = llm.cuda()
|
130 |
+
|
131 |
+
pbar = tqdm(total=len(requests))
|
132 |
+
start = time.perf_counter()
|
133 |
+
batch: List[str] = []
|
134 |
+
max_prompt_len = 0
|
135 |
+
max_output_len = 0
|
136 |
+
for i in range(len(requests)):
|
137 |
+
prompt, prompt_len, output_len = requests[i]
|
138 |
+
# Add the prompt to the batch.
|
139 |
+
batch.append(prompt)
|
140 |
+
max_prompt_len = max(max_prompt_len, prompt_len)
|
141 |
+
max_output_len = max(max_output_len, output_len)
|
142 |
+
if len(batch) < max_batch_size and i != len(requests) - 1:
|
143 |
+
# Check if we can add more requests to the batch.
|
144 |
+
_, next_prompt_len, next_output_len = requests[i + 1]
|
145 |
+
if (max(max_prompt_len, next_prompt_len) +
|
146 |
+
max(max_output_len, next_output_len)) <= 2048:
|
147 |
+
# We can add more requests to the batch.
|
148 |
+
continue
|
149 |
+
|
150 |
+
# Generate the sequences.
|
151 |
+
input_ids = tokenizer(batch, return_tensors="pt",
|
152 |
+
padding=True).input_ids
|
153 |
+
llm_outputs = llm.generate(
|
154 |
+
input_ids=input_ids.cuda(),
|
155 |
+
do_sample=not use_beam_search,
|
156 |
+
num_return_sequences=n,
|
157 |
+
temperature=1.0,
|
158 |
+
top_p=1.0,
|
159 |
+
use_cache=True,
|
160 |
+
max_new_tokens=max_output_len,
|
161 |
+
)
|
162 |
+
# Include the decoding time.
|
163 |
+
tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
|
164 |
+
pbar.update(len(batch))
|
165 |
+
|
166 |
+
# Clear the batch.
|
167 |
+
batch = []
|
168 |
+
max_prompt_len = 0
|
169 |
+
max_output_len = 0
|
170 |
+
end = time.perf_counter()
|
171 |
+
return end - start
|
172 |
+
|
173 |
+
|
174 |
+
def run_mii(
|
175 |
+
requests: List[Tuple[str, int, int]],
|
176 |
+
model: str,
|
177 |
+
tensor_parallel_size: int,
|
178 |
+
output_len: int,
|
179 |
+
) -> float:
|
180 |
+
from mii import pipeline
|
181 |
+
llm = pipeline(model, tensor_parallel=tensor_parallel_size)
|
182 |
+
prompts = [prompt for prompt, _, _ in requests]
|
183 |
+
|
184 |
+
start = time.perf_counter()
|
185 |
+
llm(prompts, max_new_tokens=output_len)
|
186 |
+
end = time.perf_counter()
|
187 |
+
return end - start
|
188 |
+
|
189 |
+
|
190 |
+
def main(args: argparse.Namespace):
|
191 |
+
print(args)
|
192 |
+
random.seed(args.seed)
|
193 |
+
|
194 |
+
# Sample the requests.
|
195 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
196 |
+
args.tokenizer, trust_remote_code=args.trust_remote_code)
|
197 |
+
if args.dataset is None:
|
198 |
+
# Synthesize a prompt with the given input length.
|
199 |
+
prompt = "hi" * (args.input_len - 1)
|
200 |
+
requests = [(prompt, args.input_len, args.output_len)
|
201 |
+
for _ in range(args.num_prompts)]
|
202 |
+
else:
|
203 |
+
requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
|
204 |
+
args.output_len)
|
205 |
+
|
206 |
+
if args.backend == "vllm":
|
207 |
+
elapsed_time = run_vllm(requests, args.model, args.tokenizer,
|
208 |
+
args.quantization, args.tensor_parallel_size,
|
209 |
+
args.seed, args.n, args.use_beam_search,
|
210 |
+
args.trust_remote_code, args.dtype,
|
211 |
+
args.max_model_len, args.enforce_eager,
|
212 |
+
args.kv_cache_dtype)
|
213 |
+
elif args.backend == "hf":
|
214 |
+
assert args.tensor_parallel_size == 1
|
215 |
+
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
|
216 |
+
args.use_beam_search, args.hf_max_batch_size,
|
217 |
+
args.trust_remote_code)
|
218 |
+
elif args.backend == "mii":
|
219 |
+
elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
|
220 |
+
args.output_len)
|
221 |
+
else:
|
222 |
+
raise ValueError(f"Unknown backend: {args.backend}")
|
223 |
+
total_num_tokens = sum(prompt_len + output_len
|
224 |
+
for _, prompt_len, output_len in requests)
|
225 |
+
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
|
226 |
+
f"{total_num_tokens / elapsed_time:.2f} tokens/s")
|
227 |
+
|
228 |
+
|
229 |
+
if __name__ == "__main__":
|
230 |
+
parser = argparse.ArgumentParser(description="Benchmark the throughput.")
|
231 |
+
parser.add_argument("--backend",
|
232 |
+
type=str,
|
233 |
+
choices=["vllm", "hf", "mii"],
|
234 |
+
default="vllm")
|
235 |
+
parser.add_argument("--dataset",
|
236 |
+
type=str,
|
237 |
+
default=None,
|
238 |
+
help="Path to the dataset.")
|
239 |
+
parser.add_argument("--input-len",
|
240 |
+
type=int,
|
241 |
+
default=None,
|
242 |
+
help="Input prompt length for each request")
|
243 |
+
parser.add_argument("--output-len",
|
244 |
+
type=int,
|
245 |
+
default=None,
|
246 |
+
help="Output length for each request. Overrides the "
|
247 |
+
"output length from the dataset.")
|
248 |
+
parser.add_argument("--model", type=str, default="facebook/opt-125m")
|
249 |
+
parser.add_argument("--tokenizer", type=str, default=None)
|
250 |
+
parser.add_argument('--quantization',
|
251 |
+
'-q',
|
252 |
+
choices=['awq', 'gptq', 'squeezellm', None],
|
253 |
+
default=None)
|
254 |
+
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
|
255 |
+
parser.add_argument("--n",
|
256 |
+
type=int,
|
257 |
+
default=1,
|
258 |
+
help="Number of generated sequences per prompt.")
|
259 |
+
parser.add_argument("--use-beam-search", action="store_true")
|
260 |
+
parser.add_argument("--num-prompts",
|
261 |
+
type=int,
|
262 |
+
default=1000,
|
263 |
+
help="Number of prompts to process.")
|
264 |
+
parser.add_argument("--seed", type=int, default=0)
|
265 |
+
parser.add_argument("--hf-max-batch-size",
|
266 |
+
type=int,
|
267 |
+
default=None,
|
268 |
+
help="Maximum batch size for HF backend.")
|
269 |
+
parser.add_argument('--trust-remote-code',
|
270 |
+
action='store_true',
|
271 |
+
help='trust remote code from huggingface')
|
272 |
+
parser.add_argument(
|
273 |
+
'--max-model-len',
|
274 |
+
type=int,
|
275 |
+
default=None,
|
276 |
+
help='Maximum length of a sequence (including prompt and output). '
|
277 |
+
'If None, will be derived from the model.')
|
278 |
+
parser.add_argument(
|
279 |
+
'--dtype',
|
280 |
+
type=str,
|
281 |
+
default='auto',
|
282 |
+
choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
|
283 |
+
help='data type for model weights and activations. '
|
284 |
+
'The "auto" option will use FP16 precision '
|
285 |
+
'for FP32 and FP16 models, and BF16 precision '
|
286 |
+
'for BF16 models.')
|
287 |
+
parser.add_argument("--enforce-eager",
|
288 |
+
action="store_true",
|
289 |
+
help="enforce eager execution")
|
290 |
+
parser.add_argument(
|
291 |
+
"--kv-cache-dtype",
|
292 |
+
type=str,
|
293 |
+
choices=["auto", "fp8_e5m2"],
|
294 |
+
default="auto",
|
295 |
+
help=
|
296 |
+
'Data type for kv cache storage. If "auto", will use model data type.')
|
297 |
+
args = parser.parse_args()
|
298 |
+
if args.tokenizer is None:
|
299 |
+
args.tokenizer = args.model
|
300 |
+
if args.dataset is None:
|
301 |
+
assert args.input_len is not None
|
302 |
+
assert args.output_len is not None
|
303 |
+
else:
|
304 |
+
assert args.input_len is None
|
305 |
+
|
306 |
+
if args.backend == "vllm":
|
307 |
+
if args.hf_max_batch_size is not None:
|
308 |
+
raise ValueError("HF max batch size is only for HF backend.")
|
309 |
+
elif args.backend == "hf":
|
310 |
+
if args.hf_max_batch_size is None:
|
311 |
+
raise ValueError("HF max batch size is required for HF backend.")
|
312 |
+
if args.quantization is not None:
|
313 |
+
raise ValueError("Quantization is only for vLLM backend.")
|
314 |
+
elif args.backend == "mii":
|
315 |
+
if args.dtype != "auto":
|
316 |
+
raise ValueError("dtype must be auto for MII backend.")
|
317 |
+
if args.n != 1:
|
318 |
+
raise ValueError("n must be 1 for MII backend.")
|
319 |
+
if args.use_beam_search:
|
320 |
+
raise ValueError("Beam search is not supported for MII backend.")
|
321 |
+
if args.quantization is not None:
|
322 |
+
raise ValueError("Quantization is only for vLLM backend.")
|
323 |
+
if args.hf_max_batch_size is not None:
|
324 |
+
raise ValueError("HF max batch size is only for HF backend.")
|
325 |
+
if args.tokenizer != args.model:
|
326 |
+
raise ValueError("Tokenizer must be the same as the model for MII "
|
327 |
+
"backend.")
|
328 |
+
main(args)
|
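The throughput numbers printed at the end of this script are simple ratios over the measured wall-clock time. A minimal sketch of the same arithmetic, with made-up request lengths and elapsed time, for reference:

```python
# Each request is (prompt, prompt_len, output_len); the lengths are illustrative.
requests = [("...", 256, 256), ("...", 512, 128), ("...", 128, 64)]
elapsed_time = 4.2  # seconds, as measured around the generation loop

total_num_tokens = sum(prompt_len + output_len
                       for _, prompt_len, output_len in requests)
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
      f"{total_num_tokens / elapsed_time:.2f} tokens/s")
```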
benchmarks/kernels/benchmark_paged_attention.py
ADDED
@@ -0,0 +1,196 @@
1 |
+
from typing import Optional
|
2 |
+
import argparse
|
3 |
+
import random
|
4 |
+
import time
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
|
9 |
+
from vllm._C import ops
|
10 |
+
|
11 |
+
NUM_BLOCKS = 1024
|
12 |
+
PARTITION_SIZE = 512
|
13 |
+
|
14 |
+
|
15 |
+
@torch.inference_mode()
|
16 |
+
def main(
|
17 |
+
version: str,
|
18 |
+
num_seqs: int,
|
19 |
+
context_len: int,
|
20 |
+
num_query_heads: int,
|
21 |
+
num_kv_heads: int,
|
22 |
+
head_size: int,
|
23 |
+
use_alibi: bool,
|
24 |
+
block_size: int,
|
25 |
+
dtype: torch.dtype,
|
26 |
+
seed: int,
|
27 |
+
do_profile: bool,
|
28 |
+
kv_cache_dtype: Optional[str] = None,
|
29 |
+
) -> None:
|
30 |
+
random.seed(seed)
|
31 |
+
torch.random.manual_seed(seed)
|
32 |
+
torch.cuda.manual_seed(seed)
|
33 |
+
|
34 |
+
scale = float(1.0 / (head_size**0.5))
|
35 |
+
query = torch.empty(num_seqs,
|
36 |
+
num_query_heads,
|
37 |
+
head_size,
|
38 |
+
dtype=dtype,
|
39 |
+
device="cuda")
|
40 |
+
query.uniform_(-scale, scale)
|
41 |
+
|
42 |
+
assert num_query_heads % num_kv_heads == 0
|
43 |
+
alibi_slopes = None
|
44 |
+
if use_alibi:
|
45 |
+
alibi_slopes = torch.randn(num_query_heads,
|
46 |
+
dtype=torch.float,
|
47 |
+
device="cuda")
|
48 |
+
|
49 |
+
context_lens = [context_len for _ in range(num_seqs)]
|
50 |
+
max_context_len = max(context_lens)
|
51 |
+
context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
|
52 |
+
|
53 |
+
# Create the block tables.
|
54 |
+
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
|
55 |
+
block_tables = []
|
56 |
+
for _ in range(num_seqs):
|
57 |
+
block_table = [
|
58 |
+
random.randint(0, NUM_BLOCKS - 1)
|
59 |
+
for _ in range(max_num_blocks_per_seq)
|
60 |
+
]
|
61 |
+
block_tables.append(block_table)
|
62 |
+
block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")
|
63 |
+
|
64 |
+
# Create the KV cache.
|
65 |
+
key_caches, value_caches = create_kv_caches_with_random(
|
66 |
+
NUM_BLOCKS, block_size, 1, num_kv_heads, head_size, kv_cache_dtype,
|
67 |
+
dtype)
|
68 |
+
key_cache, value_cache = key_caches[0], value_caches[0]
|
69 |
+
|
70 |
+
# Prepare for the paged attention kernel.
|
71 |
+
output = torch.empty_like(query)
|
72 |
+
if version == "v2":
|
73 |
+
num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
|
74 |
+
PARTITION_SIZE)
|
75 |
+
tmp_output = torch.empty(
|
76 |
+
size=(num_seqs, num_query_heads, num_partitions, head_size),
|
77 |
+
dtype=output.dtype,
|
78 |
+
device=output.device,
|
79 |
+
)
|
80 |
+
exp_sums = torch.empty(
|
81 |
+
size=(num_seqs, num_query_heads, num_partitions),
|
82 |
+
dtype=torch.float32,
|
83 |
+
device=output.device,
|
84 |
+
)
|
85 |
+
max_logits = torch.empty_like(exp_sums)
|
86 |
+
|
87 |
+
def run_benchmark(num_iters: int, profile: bool = False) -> float:
|
88 |
+
torch.cuda.synchronize()
|
89 |
+
if profile:
|
90 |
+
torch.cuda.cudart().cudaProfilerStart()
|
91 |
+
start_time = time.perf_counter()
|
92 |
+
|
93 |
+
for _ in range(num_iters):
|
94 |
+
if version == "v1":
|
95 |
+
ops.paged_attention_v1(
|
96 |
+
output,
|
97 |
+
query,
|
98 |
+
key_cache,
|
99 |
+
value_cache,
|
100 |
+
num_kv_heads,
|
101 |
+
scale,
|
102 |
+
block_tables,
|
103 |
+
context_lens,
|
104 |
+
block_size,
|
105 |
+
max_context_len,
|
106 |
+
alibi_slopes,
|
107 |
+
kv_cache_dtype,
|
108 |
+
)
|
109 |
+
elif version == "v2":
|
110 |
+
ops.paged_attention_v2(
|
111 |
+
output,
|
112 |
+
exp_sums,
|
113 |
+
max_logits,
|
114 |
+
tmp_output,
|
115 |
+
query,
|
116 |
+
key_cache,
|
117 |
+
value_cache,
|
118 |
+
num_kv_heads,
|
119 |
+
scale,
|
120 |
+
block_tables,
|
121 |
+
context_lens,
|
122 |
+
block_size,
|
123 |
+
max_context_len,
|
124 |
+
alibi_slopes,
|
125 |
+
kv_cache_dtype,
|
126 |
+
)
|
127 |
+
else:
|
128 |
+
raise ValueError(f"Invalid version: {version}")
|
129 |
+
torch.cuda.synchronize()
|
130 |
+
|
131 |
+
end_time = time.perf_counter()
|
132 |
+
if profile:
|
133 |
+
torch.cuda.cudart().cudaProfilerStart()
|
134 |
+
return (end_time - start_time) / num_iters
|
135 |
+
|
136 |
+
# Warmup.
|
137 |
+
print("Warming up...")
|
138 |
+
run_benchmark(num_iters=3, profile=False)
|
139 |
+
|
140 |
+
# Benchmark.
|
141 |
+
if do_profile:
|
142 |
+
latency = run_benchmark(num_iters=1, profile=True)
|
143 |
+
else:
|
144 |
+
latency = run_benchmark(num_iters=100, profile=False)
|
145 |
+
print(f"Kernel running time: {latency * 1000000:.3f} us")
|
146 |
+
|
147 |
+
|
148 |
+
if __name__ == '__main__':
|
149 |
+
parser = argparse.ArgumentParser(
|
150 |
+
description="Benchmark the paged attention kernel.")
|
151 |
+
parser.add_argument("--version",
|
152 |
+
type=str,
|
153 |
+
choices=["v1", "v2"],
|
154 |
+
default="v2")
|
155 |
+
parser.add_argument("--batch-size", type=int, default=8)
|
156 |
+
parser.add_argument("--context-len", type=int, default=4096)
|
157 |
+
parser.add_argument("--num-query-heads", type=int, default=64)
|
158 |
+
parser.add_argument("--num-kv-heads", type=int, default=8)
|
159 |
+
parser.add_argument("--head-size",
|
160 |
+
type=int,
|
161 |
+
choices=[64, 80, 96, 112, 128, 256],
|
162 |
+
default=128)
|
163 |
+
parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
|
164 |
+
parser.add_argument("--use-alibi", action="store_true")
|
165 |
+
parser.add_argument("--dtype",
|
166 |
+
type=str,
|
167 |
+
choices=["half", "bfloat16", "float"],
|
168 |
+
default="half")
|
169 |
+
parser.add_argument("--seed", type=int, default=0)
|
170 |
+
parser.add_argument("--profile", action="store_true")
|
171 |
+
parser.add_argument(
|
172 |
+
"--kv-cache-dtype",
|
173 |
+
type=str,
|
174 |
+
choices=["auto", "fp8_e5m2"],
|
175 |
+
default="auto",
|
176 |
+
help=
|
177 |
+
'Data type for kv cache storage. If "auto", will use model data type.')
|
178 |
+
args = parser.parse_args()
|
179 |
+
print(args)
|
180 |
+
|
181 |
+
if args.num_query_heads % args.num_kv_heads != 0:
|
182 |
+
raise ValueError("num_query_heads must be divisible by num_kv_heads")
|
183 |
+
main(
|
184 |
+
version=args.version,
|
185 |
+
num_seqs=args.batch_size,
|
186 |
+
context_len=args.context_len,
|
187 |
+
num_query_heads=args.num_query_heads,
|
188 |
+
num_kv_heads=args.num_kv_heads,
|
189 |
+
head_size=args.head_size,
|
190 |
+
block_size=args.block_size,
|
191 |
+
use_alibi=args.use_alibi,
|
192 |
+
dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
|
193 |
+
seed=args.seed,
|
194 |
+
do_profile=args.profile,
|
195 |
+
kv_cache_dtype=args.kv_cache_dtype,
|
196 |
+
)
|
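This benchmark builds a block table per sequence that maps each logical KV-cache block to a randomly chosen physical block, mimicking how PagedAttention addresses a pool of `NUM_BLOCKS` fixed-size blocks. A small sketch of that mapping (constants and context length are illustrative):

```python
import random

NUM_BLOCKS = 1024   # size of the physical KV-cache pool
BLOCK_SIZE = 16     # tokens per block
context_len = 100   # tokens already generated/prefilled for this sequence

# Number of logical blocks needed to cover the context (ceiling division).
num_logical_blocks = (context_len + BLOCK_SIZE - 1) // BLOCK_SIZE

# Map logical block i -> some physical block index in the pool.
block_table = [random.randint(0, NUM_BLOCKS - 1)
               for _ in range(num_logical_blocks)]
print(block_table)
```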
benchmarks/launch_tgi_server.sh
ADDED
@@ -0,0 +1,16 @@
#!/bin/bash

PORT=8000
MODEL=$1
TOKENS=$2

docker run --gpus all --shm-size 1g -p $PORT:80 \
           -v $PWD/data:/data \
           ghcr.io/huggingface/text-generation-inference:0.8 \
           --model-id $MODEL \
           --sharded false \
           --max-input-length 1024 \
           --max-total-tokens 2048 \
           --max-best-of 5 \
           --max-concurrent-requests 5000 \
           --max-batch-total-tokens $TOKENS
csrc/activation_kernels.cu
ADDED
@@ -0,0 +1,118 @@
1 |
+
#include <ATen/cuda/CUDAContext.h>
|
2 |
+
#include <torch/extension.h>
|
3 |
+
#include <c10/cuda/CUDAGuard.h>
|
4 |
+
|
5 |
+
#include "cuda_compat.h"
|
6 |
+
#include "dispatch_utils.h"
|
7 |
+
|
8 |
+
namespace vllm {
|
9 |
+
|
10 |
+
template<typename T>
|
11 |
+
__device__ __forceinline__ T silu(const T& x) {
|
12 |
+
// x * sigmoid(x)
|
13 |
+
return (T) (((float) x) / (1.0f + expf((float) -x)));
|
14 |
+
}
|
15 |
+
|
16 |
+
template<typename scalar_t>
|
17 |
+
__global__ void silu_and_mul_kernel(
|
18 |
+
scalar_t* __restrict__ out, // [..., d]
|
19 |
+
const scalar_t* __restrict__ input, // [..., 2, d]
|
20 |
+
const int d) {
|
21 |
+
const int64_t token_idx = blockIdx.x;
|
22 |
+
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
23 |
+
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
|
24 |
+
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
|
25 |
+
out[token_idx * d + idx] = silu(x) * y;
|
26 |
+
}
|
27 |
+
}
|
28 |
+
|
29 |
+
} // namespace vllm
|
30 |
+
|
31 |
+
void silu_and_mul(
|
32 |
+
torch::Tensor& out, // [..., d]
|
33 |
+
torch::Tensor& input) // [..., 2 * d]
|
34 |
+
{
|
35 |
+
int64_t num_tokens = input.numel() / input.size(-1);
|
36 |
+
int d = input.size(-1) / 2;
|
37 |
+
|
38 |
+
dim3 grid(num_tokens);
|
39 |
+
dim3 block(std::min(d, 1024));
|
40 |
+
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
41 |
+
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
42 |
+
VLLM_DISPATCH_FLOATING_TYPES(
|
43 |
+
input.scalar_type(),
|
44 |
+
"silu_and_mul_kernel",
|
45 |
+
[&] {
|
46 |
+
vllm::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
47 |
+
out.data_ptr<scalar_t>(),
|
48 |
+
input.data_ptr<scalar_t>(),
|
49 |
+
d);
|
50 |
+
});
|
51 |
+
}
|
52 |
+
|
53 |
+
namespace vllm {
|
54 |
+
|
55 |
+
// Element-wise activation kernel template.
|
56 |
+
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
|
57 |
+
__global__ void activation_kernel(
|
58 |
+
scalar_t* __restrict__ out, // [..., d]
|
59 |
+
const scalar_t* __restrict__ input, // [..., d]
|
60 |
+
const int d) {
|
61 |
+
const int64_t token_idx = blockIdx.x;
|
62 |
+
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
63 |
+
const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
|
64 |
+
out[token_idx * d + idx] = ACT_FN(x);
|
65 |
+
}
|
66 |
+
}
|
67 |
+
|
68 |
+
} // namespace vllm
|
69 |
+
|
70 |
+
// Launch element-wise activation kernel.
|
71 |
+
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
|
72 |
+
int d = input.size(-1); \
|
73 |
+
int64_t num_tokens = input.numel() / d; \
|
74 |
+
dim3 grid(num_tokens); \
|
75 |
+
dim3 block(std::min(d, 1024)); \
|
76 |
+
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
|
77 |
+
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
|
78 |
+
VLLM_DISPATCH_FLOATING_TYPES( \
|
79 |
+
input.scalar_type(), \
|
80 |
+
"activation_kernel", \
|
81 |
+
[&] { \
|
82 |
+
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
|
83 |
+
out.data_ptr<scalar_t>(), \
|
84 |
+
input.data_ptr<scalar_t>(), \
|
85 |
+
d); \
|
86 |
+
});
|
87 |
+
|
88 |
+
namespace vllm {
|
89 |
+
|
90 |
+
template<typename T>
|
91 |
+
__device__ __forceinline__ T gelu_new_kernel(const T& x) {
|
92 |
+
const float x3 = (float) (x * x * x);
|
93 |
+
const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3))));
|
94 |
+
return ((T) 0.5) * x * (((T) 1.0) + t);
|
95 |
+
}
|
96 |
+
|
97 |
+
template<typename T>
|
98 |
+
__device__ __forceinline__ T gelu_fast_kernel(const T& x) {
|
99 |
+
const float f = (float) x;
|
100 |
+
const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x));
|
101 |
+
return ((T) 0.5) * x * (((T) 1.0) + t);
|
102 |
+
}
|
103 |
+
|
104 |
+
} // namespace vllm
|
105 |
+
|
106 |
+
void gelu_new(
|
107 |
+
torch::Tensor& out, // [..., d]
|
108 |
+
torch::Tensor& input) // [..., d]
|
109 |
+
{
|
110 |
+
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
|
111 |
+
}
|
112 |
+
|
113 |
+
void gelu_fast(
|
114 |
+
torch::Tensor& out, // [..., d]
|
115 |
+
torch::Tensor& input) // [..., d]
|
116 |
+
{
|
117 |
+
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
|
118 |
+
}
|
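For readers more comfortable in Python: per token, the `silu_and_mul` kernel above computes `silu(x[:d]) * x[d:]`, where the input packs two `d`-sized halves. A PyTorch reference of the same math, as a sketch for checking semantics rather than how vLLM invokes the kernel:

```python
import torch
import torch.nn.functional as F


def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    """x has shape [..., 2 * d]; returns silu(first half) * second half."""
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]


out = silu_and_mul_reference(torch.randn(4, 2 * 11008))
print(out.shape)  # torch.Size([4, 11008])
```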
csrc/attention/attention_dtypes.h
ADDED
@@ -0,0 +1,7 @@
#pragma once

#include "attention_generic.cuh"
#include "dtype_float16.cuh"
#include "dtype_float32.cuh"
#include "dtype_bfloat16.cuh"
#include "dtype_fp8_e5m2.cuh"
csrc/attention/attention_generic.cuh
ADDED
@@ -0,0 +1,64 @@
/*
 * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include <stdint.h>

namespace vllm {

// A vector type to store Q, K, V elements.
template<typename T, int VEC_SIZE>
struct Vec {};

// A vector type to store FP32 accumulators.
template<typename T>
struct FloatVec {};

// Template vector operations.
template<typename Acc, typename A, typename B>
inline __device__ Acc mul(A a, B b);

template<typename T>
inline __device__ float sum(T v);

template<typename T>
inline __device__ float dot(T a, T b) {
  return sum(mul<T, T, T>(a, b));
}

template<typename A, typename T>
inline __device__ float dot(T a, T b) {
  return sum(mul<A, T, T>(a, b));
}

template<typename T>
inline __device__ void zero(T& dst) {
  constexpr int WORDS = sizeof(T) / 4;
  union {
    T raw;
    uint32_t words[WORDS];
  } tmp;

#pragma unroll
  for (int ii = 0; ii < WORDS; ++ii) {
    tmp.words[ii] = 0u;
  }
  dst = tmp.raw;
}

} // namespace vllm
csrc/attention/attention_kernels.cu
ADDED
@@ -0,0 +1,951 @@
1 |
+
/*
|
2 |
+
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
|
3 |
+
* Copyright (c) 2023, The vLLM team.
|
4 |
+
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
5 |
+
*
|
6 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
* you may not use this file except in compliance with the License.
|
8 |
+
* You may obtain a copy of the License at
|
9 |
+
*
|
10 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
*
|
12 |
+
* Unless required by applicable law or agreed to in writing, software
|
13 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
* See the License for the specific language governing permissions and
|
16 |
+
* limitations under the License.
|
17 |
+
*/
|
18 |
+
#ifdef USE_ROCM
|
19 |
+
#include <hip/hip_runtime.h>
|
20 |
+
#endif
|
21 |
+
|
22 |
+
#include <torch/extension.h>
|
23 |
+
#include <ATen/cuda/CUDAContext.h>
|
24 |
+
#include <c10/cuda/CUDAGuard.h>
|
25 |
+
|
26 |
+
#include "attention_dtypes.h"
|
27 |
+
#include "attention_utils.cuh"
|
28 |
+
#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
|
29 |
+
|
30 |
+
#include <algorithm>
|
31 |
+
|
32 |
+
#ifndef USE_ROCM
|
33 |
+
#define WARP_SIZE 32
|
34 |
+
#else
|
35 |
+
#define WARP_SIZE warpSize
|
36 |
+
#endif
|
37 |
+
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
38 |
+
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
39 |
+
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
|
40 |
+
|
41 |
+
namespace vllm {
|
42 |
+
|
43 |
+
// Utility function for attention softmax.
|
44 |
+
template<int NUM_WARPS>
|
45 |
+
inline __device__ float block_sum(float* red_smem, float sum) {
|
46 |
+
// Decompose the thread index into warp / lane.
|
47 |
+
int warp = threadIdx.x / WARP_SIZE;
|
48 |
+
int lane = threadIdx.x % WARP_SIZE;
|
49 |
+
|
50 |
+
// Compute the sum per warp.
|
51 |
+
#pragma unroll
|
52 |
+
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
|
53 |
+
sum += VLLM_SHFL_XOR_SYNC(sum, mask);
|
54 |
+
}
|
55 |
+
|
56 |
+
// Warp leaders store the data to shared memory.
|
57 |
+
if (lane == 0) {
|
58 |
+
red_smem[warp] = sum;
|
59 |
+
}
|
60 |
+
|
61 |
+
// Make sure the data is in shared memory.
|
62 |
+
__syncthreads();
|
63 |
+
|
64 |
+
// The warps compute the final sums.
|
65 |
+
if (lane < NUM_WARPS) {
|
66 |
+
sum = red_smem[lane];
|
67 |
+
}
|
68 |
+
|
69 |
+
// Parallel reduction inside the warp.
|
70 |
+
#pragma unroll
|
71 |
+
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
|
72 |
+
sum += VLLM_SHFL_XOR_SYNC(sum, mask);
|
73 |
+
}
|
74 |
+
|
75 |
+
// Broadcast to other threads.
|
76 |
+
return VLLM_SHFL_SYNC(sum, 0);
|
77 |
+
}
|
78 |
+
|
79 |
+
// TODO(woosuk): Merge the last two dimensions of the grid.
|
80 |
+
// Grid: (num_heads, num_seqs, max_num_partitions).
|
81 |
+
template<
|
82 |
+
typename scalar_t,
|
83 |
+
typename cache_t,
|
84 |
+
int HEAD_SIZE,
|
85 |
+
  int BLOCK_SIZE,
  int NUM_THREADS,
  bool IS_FP8_E5M2_KV_CACHE,
  int PARTITION_SIZE = 0> // Zero means no partitioning.
__device__ void paged_attention_kernel(
  float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
  float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
  scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size]
  const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
  const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
  const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
  const int num_kv_heads, // [num_heads]
  const float scale,
  const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
  const int* __restrict__ context_lens, // [num_seqs]
  const int max_num_blocks_per_seq,
  const float* __restrict__ alibi_slopes, // [num_heads]
  const int q_stride,
  const int kv_block_stride,
  const int kv_head_stride) {
  const int seq_idx = blockIdx.y;
  const int partition_idx = blockIdx.z;
  const int max_num_partitions = gridDim.z;
  constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
  const int context_len = context_lens[seq_idx];
  if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) {
    // No work to do. Terminate the thread block.
    return;
  }

  const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
  const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;

  // [start_block_idx, end_block_idx) is the range of blocks to process.
  const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
  const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks);
  const int num_blocks = end_block_idx - start_block_idx;

  // [start_token_idx, end_token_idx) is the range of tokens to process.
  const int start_token_idx = start_block_idx * BLOCK_SIZE;
  const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len);
  const int num_tokens = end_token_idx - start_token_idx;

  constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS
  assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
  constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  const int thread_idx = threadIdx.x;
  const int warp_idx = thread_idx / WARP_SIZE;
  const int lane = thread_idx % WARP_SIZE;

  const int head_idx = blockIdx.x;
  const int num_heads = gridDim.x;
  const int num_queries_per_kv = num_heads / num_kv_heads;
  const int kv_head_idx = head_idx / num_queries_per_kv;
  const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];

  // A vector type to store a part of a key or a query.
  // The vector size is configured in such a way that the threads in a thread group
  // fetch or compute 16 bytes at a time.
  // For example, if the size of a thread group is 4 and the data type is half,
  // then the vector size is 16 / (4 * sizeof(half)) == 2.
  constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
  using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
  using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
#ifdef ENABLE_FP8_E5M2
  using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
#endif

  constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
  constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;

  const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
  const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;

  // Load the query to registers.
  // Each thread in a thread group has a different part of the query.
  // For example, if the the thread group size is 4, then the first thread in the group
  // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
  // th vectors of the query, and so on.
  // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
  const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
  __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
#pragma unroll
  for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
    const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
    q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
  }
  __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs

  // Memory planning.
  extern __shared__ char shared_mem[];
  // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy.
  float* logits = reinterpret_cast<float*>(shared_mem);
  // Workspace for reduction.
  __shared__ float red_smem[2 * NUM_WARPS];

  // x == THREAD_GROUP_SIZE * VEC_SIZE
  // Each thread group fetches x elements from the key at a time.
  constexpr int x = 16 / sizeof(cache_t);
  float qk_max = -FLT_MAX;

  // Iterate over the key blocks.
  // Each warp fetches a block of keys for each iteration.
  // Each thread group in a warp fetches a key from the block, and computes
  // dot product with the query.
  const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
    // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64
    // because int32 can lead to overflow when this variable is multiplied by large numbers
    // (e.g., kv_block_stride).
    const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);

    // Load a key to registers.
    // Each thread in a thread group has a different part of the key.
    // For example, if the the thread group size is 4, then the first thread in the group
    // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th
    // vectors of the key, and so on.
    for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
      const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
      const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
      K_vec k_vecs[NUM_VECS_PER_THREAD];

#pragma unroll
      for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
        const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride
                                       + kv_head_idx * kv_head_stride
                                       + physical_block_offset * x;
        const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
        const int offset1 = (vec_idx * VEC_SIZE) / x;
        const int offset2 = (vec_idx * VEC_SIZE) % x;
        if constexpr (IS_FP8_E5M2_KV_CACHE) {
#ifdef ENABLE_FP8_E5M2
          Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
          // Vector conversion from Quant_vec to K_vec.
          k_vecs[j] = fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
#else
          assert(false);
#endif
        } else {
          k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
        }
      }

      // Compute dot product.
      // This includes a reduction across the threads in the same thread group.
      float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
      // Add the ALiBi bias if slopes are given.
      qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;

      if (thread_group_offset == 0) {
        // Store the partial reductions to shared memory.
        // NOTE(woosuk): It is required to zero out the masked logits.
        const bool mask = token_idx >= context_len;
        logits[token_idx - start_token_idx] = mask ? 0.f : qk;
        // Update the max value.
        qk_max = mask ? qk_max : fmaxf(qk_max, qk);
      }
    }
  }

  // Perform reduction across the threads in the same warp to get the
  // max qk value for each "warp" (not across the thread block yet).
  // The 0-th thread of each thread group already has its max qk value.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
    qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
  }
  if (lane == 0) {
    red_smem[warp_idx] = qk_max;
  }
  __syncthreads();

  // TODO(woosuk): Refactor this part.
  // Get the max qk value for the sequence.
  qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
  }
  // Broadcast the max qk value to all threads.
  qk_max = VLLM_SHFL_SYNC(qk_max, 0);

  // Get the sum of the exp values.
  float exp_sum = 0.f;
  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
    float val = __expf(logits[i] - qk_max);
    logits[i] = val;
    exp_sum += val;
  }
  exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);

  // Compute softmax.
  const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
    logits[i] *= inv_sum;
  }
  __syncthreads();

  // If partitioning is enabled, store the max logit and exp_sum.
  if (USE_PARTITIONING && thread_idx == 0) {
    float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
                                       + head_idx * max_num_partitions
                                       + partition_idx;
    *max_logits_ptr = qk_max;
    float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
                                   + head_idx * max_num_partitions
                                   + partition_idx;
    *exp_sums_ptr = exp_sum;
  }

  // Each thread will fetch 16 bytes from the value cache at a time.
  constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
  using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
  using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
#ifdef ENABLE_FP8_E5M2
  using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
#endif
  using Float_L_vec = typename FloatVec<L_vec>::Type;

  constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
  constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
  constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);

  // NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
  float accs[NUM_ROWS_PER_THREAD];
#pragma unroll
  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
    accs[i] = 0.f;
  }

  scalar_t zero_value;
  zero(zero_value);
  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
    // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64
    // because int32 can lead to overflow when this variable is multiplied by large numbers
    // (e.g., kv_block_stride).
    const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
    const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
    const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
    L_vec logits_vec;
    from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));

    const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
                                   + kv_head_idx * kv_head_stride;
#pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
      if (row_idx < HEAD_SIZE) {
        const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
        V_vec v_vec;
        if constexpr (IS_FP8_E5M2_KV_CACHE) {
#ifdef ENABLE_FP8_E5M2
          V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
          // Vector conversion from V_quant_vec to V_vec.
          v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(v_quant_vec);
#else
          assert(false);
#endif
        } else {
          v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
        }
        if (block_idx == num_context_blocks - 1) {
          // NOTE(woosuk): When v_vec contains the tokens that are out of the context,
          // we should explicitly zero out the values since they may contain NaNs.
          // See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
          scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
          for (int j = 0; j < V_VEC_SIZE; j++) {
            v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
          }
        }
        accs[i] += dot(logits_vec, v_vec);
      }
    }
  }

  // Perform reduction within each warp.
#pragma unroll
  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
    float acc = accs[i];
#pragma unroll
    for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
      acc += VLLM_SHFL_XOR_SYNC(acc, mask);
    }
    accs[i] = acc;
  }

  // NOTE(woosuk): A barrier is required because the shared memory space for logits
  // is reused for the output.
  __syncthreads();

  // Perform reduction across warps.
  float* out_smem = reinterpret_cast<float*>(shared_mem);
#pragma unroll
  for (int i = NUM_WARPS; i > 1; i /= 2) {
    int mid = i / 2;
    // Upper warps write to shared memory.
    if (warp_idx >= mid && warp_idx < i) {
      float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
#pragma unroll
      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
          dst[row_idx] = accs[i];
        }
      }
    }
    __syncthreads();

    // Lower warps update the output.
    if (warp_idx < mid) {
      const float* src = &out_smem[warp_idx * HEAD_SIZE];
#pragma unroll
      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
          accs[i] += src[row_idx];
        }
      }
    }
    __syncthreads();
  }

  // Write the final output.
  if (warp_idx == 0) {
    scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
                            + head_idx * max_num_partitions * HEAD_SIZE
                            + partition_idx * HEAD_SIZE;
#pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
      if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
        from_float(*(out_ptr + row_idx), accs[i]);
      }
    }
  }
}

// Grid: (num_heads, num_seqs, 1).
template<
  typename scalar_t,
  typename cache_t,
  int HEAD_SIZE,
  int BLOCK_SIZE,
  int NUM_THREADS,
  bool IS_FP8_E5M2_KV_CACHE>
__global__ void paged_attention_v1_kernel(
  scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
  const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
  const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
  const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
  const int num_kv_heads, // [num_heads]
  const float scale,
  const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
  const int* __restrict__ context_lens, // [num_seqs]
  const int max_num_blocks_per_seq,
  const float* __restrict__ alibi_slopes, // [num_heads]
  const int q_stride,
  const int kv_block_stride,
  const int kv_head_stride) {
  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE>(
    /* exp_sums */ nullptr, /* max_logits */ nullptr,
    out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
    max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
}

// Grid: (num_heads, num_seqs, max_num_partitions).
template<
  typename scalar_t,
  typename cache_t,
  int HEAD_SIZE,
  int BLOCK_SIZE,
  int NUM_THREADS,
  bool IS_FP8_E5M2_KV_CACHE,
  int PARTITION_SIZE>
__global__ void paged_attention_v2_kernel(
  float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
  float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
  scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
  const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
  const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
  const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
  const int num_kv_heads, // [num_heads]
  const float scale,
  const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
  const int* __restrict__ context_lens, // [num_seqs]
  const int max_num_blocks_per_seq,
  const float* __restrict__ alibi_slopes, // [num_heads]
  const int q_stride,
  const int kv_block_stride,
  const int kv_head_stride) {
  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE>(
    exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
    block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
    q_stride, kv_block_stride, kv_head_stride);
}

// Grid: (num_heads, num_seqs).
template<
  typename scalar_t,
  int HEAD_SIZE,
  int NUM_THREADS,
  int PARTITION_SIZE>
__global__ void paged_attention_v2_reduce_kernel(
  scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
  const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
  const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
  const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
  const int* __restrict__ context_lens, // [num_seqs]
  const int max_num_partitions) {
  const int num_heads = gridDim.x;
  const int head_idx = blockIdx.x;
  const int seq_idx = blockIdx.y;
  const int context_len = context_lens[seq_idx];
  const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
  if (num_partitions == 1) {
    // No need to reduce. Only copy tmp_out to out.
    scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
    const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
                                          + head_idx * max_num_partitions * HEAD_SIZE;
    for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
      out_ptr[i] = tmp_out_ptr[i];
    }
    // Terminate the thread block.
    return;
  }

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  const int warp_idx = threadIdx.x / WARP_SIZE;
  const int lane = threadIdx.x % WARP_SIZE;

  // Size: 2 * num_partitions.
  extern __shared__ char shared_mem[];
  // Workspace for reduction.
  __shared__ float red_smem[2 * NUM_WARPS];

  // Load max logits to shared memory.
  float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
  const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
                                           + head_idx * max_num_partitions;
  float max_logit = -FLT_MAX;
  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
    const float l = max_logits_ptr[i];
    shared_max_logits[i] = l;
    max_logit = fmaxf(max_logit, l);
  }
  __syncthreads();

  // Get the global max logit.
  // Reduce within the warp.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
    max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
  }
  if (lane == 0) {
    red_smem[warp_idx] = max_logit;
  }
  __syncthreads();
  // Reduce across warps.
  max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
  }
  // Broadcast the max value to all threads.
  max_logit = VLLM_SHFL_SYNC(max_logit, 0);

  // Load rescaled exp sums to shared memory.
  float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
  const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
                                       + head_idx * max_num_partitions;
  float global_exp_sum = 0.0f;
  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
    float l = shared_max_logits[i];
    float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
    global_exp_sum += rescaled_exp_sum;
    shared_exp_sums[i] = rescaled_exp_sum;
  }
  __syncthreads();
  global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
  const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);

  // Aggregate tmp_out to out.
  const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
                                        + head_idx * max_num_partitions * HEAD_SIZE;
  scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
#pragma unroll
  for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
    float acc = 0.0f;
    for (int j = 0; j < num_partitions; ++j) {
      acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum;
    }
    from_float(out_ptr[i], acc);
  }
}

} // namespace vllm

#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
  VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
    ((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
      IS_FP8_E5M2_KV_CACHE>), shared_mem_size); \
  vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
    IS_FP8_E5M2_KV_CACHE><<<grid, block, shared_mem_size, stream>>>( \
    out_ptr, \
    query_ptr, \
    key_cache_ptr, \
    value_cache_ptr, \
    num_kv_heads, \
    scale, \
    block_tables_ptr, \
    context_lens_ptr, \
    max_num_blocks_per_seq, \
    alibi_slopes_ptr, \
    q_stride, \
    kv_block_stride, \
    kv_head_stride);

// TODO(woosuk): Tune NUM_THREADS.
template<
  typename T,
  typename CACHE_T,
  int BLOCK_SIZE,
  bool IS_FP8_E5M2_KV_CACHE,
  int NUM_THREADS = 128>
void paged_attention_v1_launcher(
  torch::Tensor& out,
  torch::Tensor& query,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  int num_kv_heads,
  float scale,
  torch::Tensor& block_tables,
  torch::Tensor& context_lens,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes) {
  int num_seqs = query.size(0);
  int num_heads = query.size(1);
  int head_size = query.size(2);
  int max_num_blocks_per_seq = block_tables.size(1);
  int q_stride = query.stride(0);
  int kv_block_stride = key_cache.stride(0);
  int kv_head_stride = key_cache.stride(1);

  int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  assert(head_size % thread_group_size == 0);

  // NOTE: alibi_slopes is optional.
  const float* alibi_slopes_ptr = alibi_slopes ?
    reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
    : nullptr;

  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
  CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
  int* block_tables_ptr = block_tables.data_ptr<int>();
  int* context_lens_ptr = context_lens.data_ptr<int>();

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE;
  int logits_size = padded_max_context_len * sizeof(float);
  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
  // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
  // Keep that in sync with the logic here!
  int shared_mem_size = std::max(logits_size, outputs_size);

  dim3 grid(num_heads, num_seqs, 1);
  dim3 block(NUM_THREADS);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  switch (head_size) {
    // NOTE(woosuk): To reduce the compilation time, we only compile for the
    // head sizes that we use in the model. However, we can easily extend this
    // to support any head size which is a multiple of 16.
    case 64:
      LAUNCH_PAGED_ATTENTION_V1(64);
      break;
    case 80:
      LAUNCH_PAGED_ATTENTION_V1(80);
      break;
    case 96:
      LAUNCH_PAGED_ATTENTION_V1(96);
      break;
    case 112:
      LAUNCH_PAGED_ATTENTION_V1(112);
      break;
    case 128:
      LAUNCH_PAGED_ATTENTION_V1(128);
      break;
    case 256:
      LAUNCH_PAGED_ATTENTION_V1(256);
      break;
    default:
      TORCH_CHECK(false, "Unsupported head size: ", head_size);
      break;
  }
}

#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) \
  paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>( \
    out, \
    query, \
    key_cache, \
    value_cache, \
    num_kv_heads, \
    scale, \
    block_tables, \
    context_lens, \
    max_context_len, \
    alibi_slopes);

// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
  switch (block_size) { \
    case 8: \
      CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE); \
      break; \
    case 16: \
      CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE); \
      break; \
    case 32: \
      CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE); \
      break; \
    default: \
      TORCH_CHECK(false, "Unsupported block size: ", block_size); \
      break; \
  }

void paged_attention_v1(
  torch::Tensor& out, // [num_seqs, num_heads, head_size]
  torch::Tensor& query, // [num_seqs, num_heads, head_size]
  torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
  int num_kv_heads, // [num_heads]
  float scale,
  torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
  torch::Tensor& context_lens, // [num_seqs]
  int block_size,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes,
  const std::string& kv_cache_dtype) {
  if (kv_cache_dtype == "auto") {
    if (query.dtype() == at::ScalarType::Float) {
      CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false);
    } else if (query.dtype() == at::ScalarType::Half) {
      CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
    } else if (query.dtype() == at::ScalarType::BFloat16) {
      CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
    } else {
      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
    }
  } else if (kv_cache_dtype == "fp8_e5m2") {
    if (query.dtype() == at::ScalarType::Float) {
      CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
    } else if (query.dtype() == at::ScalarType::Half) {
      CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
    } else if (query.dtype() == at::ScalarType::BFloat16) {
      CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
    } else {
      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
    }
  } else {
    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
  }
}

#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
  vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
    IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE> \
    <<<grid, block, shared_mem_size, stream>>>( \
      exp_sums_ptr, \
      max_logits_ptr, \
      tmp_out_ptr, \
      query_ptr, \
      key_cache_ptr, \
      value_cache_ptr, \
      num_kv_heads, \
      scale, \
      block_tables_ptr, \
      context_lens_ptr, \
      max_num_blocks_per_seq, \
      alibi_slopes_ptr, \
      q_stride, \
      kv_block_stride, \
      kv_head_stride); \
  vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE> \
    <<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
      out_ptr, \
      exp_sums_ptr, \
      max_logits_ptr, \
      tmp_out_ptr, \
      context_lens_ptr, \
      max_num_partitions);

template<
  typename T,
  typename CACHE_T,
  int BLOCK_SIZE,
  bool IS_FP8_E5M2_KV_CACHE,
  int NUM_THREADS = 128,
  int PARTITION_SIZE = 512>
void paged_attention_v2_launcher(
  torch::Tensor& out,
  torch::Tensor& exp_sums,
  torch::Tensor& max_logits,
  torch::Tensor& tmp_out,
  torch::Tensor& query,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  int num_kv_heads,
  float scale,
  torch::Tensor& block_tables,
  torch::Tensor& context_lens,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes) {
  int num_seqs = query.size(0);
  int num_heads = query.size(1);
  int head_size = query.size(2);
  int max_num_blocks_per_seq = block_tables.size(1);
  int q_stride = query.stride(0);
  int kv_block_stride = key_cache.stride(0);
  int kv_head_stride = key_cache.stride(1);

  int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  assert(head_size % thread_group_size == 0);

  // NOTE: alibi_slopes is optional.
  const float* alibi_slopes_ptr = alibi_slopes ?
    reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
    : nullptr;

  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
  float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
  float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
  T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
  CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
  int* block_tables_ptr = block_tables.data_ptr<int>();
  int* context_lens_ptr = context_lens.data_ptr<int>();

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
  int logits_size = PARTITION_SIZE * sizeof(float);
  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);

  // For paged attention v2 kernel.
  dim3 grid(num_heads, num_seqs, max_num_partitions);
  int shared_mem_size = std::max(logits_size, outputs_size);
  // For paged attention v2 reduce kernel.
  dim3 reduce_grid(num_heads, num_seqs);
  int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);

  dim3 block(NUM_THREADS);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  switch (head_size) {
    // NOTE(woosuk): To reduce the compilation time, we only compile for the
    // head sizes that we use in the model. However, we can easily extend this
    // to support any head size which is a multiple of 16.
    case 64:
      LAUNCH_PAGED_ATTENTION_V2(64);
      break;
    case 80:
      LAUNCH_PAGED_ATTENTION_V2(80);
      break;
    case 96:
      LAUNCH_PAGED_ATTENTION_V2(96);
      break;
    case 112:
      LAUNCH_PAGED_ATTENTION_V2(112);
      break;
    case 128:
      LAUNCH_PAGED_ATTENTION_V2(128);
      break;
    case 256:
      LAUNCH_PAGED_ATTENTION_V2(256);
      break;
    default:
      TORCH_CHECK(false, "Unsupported head size: ", head_size);
      break;
  }
}

#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) \
  paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>( \
    out, \
    exp_sums, \
    max_logits, \
    tmp_out, \
    query, \
    key_cache, \
    value_cache, \
    num_kv_heads, \
    scale, \
    block_tables, \
    context_lens, \
    max_context_len, \
    alibi_slopes);

// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
  switch (block_size) { \
    case 8: \
      CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE); \
      break; \
    case 16: \
      CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE); \
      break; \
    case 32: \
      CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE); \
      break; \
    default: \
      TORCH_CHECK(false, "Unsupported block size: ", block_size); \
      break; \
  }

void paged_attention_v2(
  torch::Tensor& out, // [num_seqs, num_heads, head_size]
  torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
  torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
  torch::Tensor& tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
  torch::Tensor& query, // [num_seqs, num_heads, head_size]
  torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
  int num_kv_heads, // [num_heads]
  float scale,
  torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
  torch::Tensor& context_lens, // [num_seqs]
  int block_size,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes,
  const std::string& kv_cache_dtype) {
  if (kv_cache_dtype == "auto") {
    if (query.dtype() == at::ScalarType::Float) {
      CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false);
    } else if (query.dtype() == at::ScalarType::Half) {
      CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
    } else if (query.dtype() == at::ScalarType::BFloat16) {
      CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
    } else {
      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
    }
  } else if (kv_cache_dtype == "fp8_e5m2") {
    if (query.dtype() == at::ScalarType::Float) {
      CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
    } else if (query.dtype() == at::ScalarType::Half) {
      CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
    } else if (query.dtype() == at::ScalarType::BFloat16) {
      CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
    } else {
      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
    }
  } else {
    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
  }
}

#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
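A quick way to sanity-check the launcher logic above is to reproduce its grid and shared-memory arithmetic on the host. The standalone C++ sketch below is not part of the diff; the batch size, head count, head size, block size, and context length are made-up example values, while the constants mirror the defaults used by the launchers above (NUM_THREADS = 128, WARP_SIZE = 32, PARTITION_SIZE = 512).

// Standalone sanity check for the launcher arithmetic above (not part of the diff).
// The batch/model numbers below are hypothetical example inputs.
#include <algorithm>
#include <cstdio>

int main() {
  const int WARP_SIZE = 32;
  const int NUM_THREADS = 128;
  const int PARTITION_SIZE = 512;

  const int num_seqs = 4;          // example batch size
  const int num_heads = 32;        // example head count
  const int head_size = 128;
  const int block_size = 16;
  const int max_context_len = 2048;

  auto divide_round_up = [](int a, int b) { return (a + b - 1) / b; };
  const int num_warps = NUM_THREADS / WARP_SIZE;

  // v1: shared memory holds either the softmax logits or the warp outputs.
  const int padded_max_context_len = divide_round_up(max_context_len, block_size) * block_size;
  const int v1_logits_size = padded_max_context_len * (int)sizeof(float);
  const int v1_outputs_size = (num_warps / 2) * head_size * (int)sizeof(float);
  const int v1_shared_mem = std::max(v1_logits_size, v1_outputs_size);

  // v2: each partition covers at most PARTITION_SIZE tokens, so the logits buffer is bounded.
  const int max_num_partitions = divide_round_up(max_context_len, PARTITION_SIZE);
  const int v2_logits_size = PARTITION_SIZE * (int)sizeof(float);
  const int v2_outputs_size = (num_warps / 2) * head_size * (int)sizeof(float);
  const int v2_shared_mem = std::max(v2_logits_size, v2_outputs_size);
  const int reduce_shared_mem = 2 * max_num_partitions * (int)sizeof(float);

  std::printf("v1 grid = (%d, %d, 1), shared_mem = %d bytes\n",
              num_heads, num_seqs, v1_shared_mem);
  std::printf("v2 grid = (%d, %d, %d), shared_mem = %d bytes, reduce smem = %d bytes\n",
              num_heads, num_seqs, max_num_partitions, v2_shared_mem, reduce_shared_mem);
  return 0;
}

The point the numbers make is that v1's shared-memory footprint grows with the padded context length, while v2's stays bounded by PARTITION_SIZE, which is why long contexts are routed through the partitioned kernel plus the reduce kernel.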
csrc/attention/attention_utils.cuh
ADDED
@@ -0,0 +1,56 @@
/*
 * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "../cuda_compat.h"
#include "attention_dtypes.h"

#include <float.h>
#include <type_traits>

namespace vllm {

// Q*K^T operation.
template<int THREAD_GROUP_SIZE, typename Vec, int N>
inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
  using A_vec = typename FloatVec<Vec>::Type;
  // Compute the parallel products for Q*K^T (treat vector lanes separately).
  A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]);
#pragma unroll
  for (int ii = 1; ii < N; ++ii) {
    qk_vec = fma(q[ii], k[ii], qk_vec);
  }

  // Finalize the reduction across lanes.
  float qk = sum(qk_vec);
#pragma unroll
  for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
    qk += VLLM_SHFL_XOR_SYNC(qk, mask);
  }
  return qk;
}

template<typename T, int THREAD_GROUP_SIZE>
struct Qk_dot {
  template<typename Vec, int N>
  static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) {
    return qk_dot_<THREAD_GROUP_SIZE>(q, k);
  }
};

} // namespace vllm
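The qk_dot_ helper above finishes the per-thread-group dot product with an XOR-shuffle butterfly reduction. The short host-side C++ sketch below is not part of the diff; it simulates that access pattern on an ordinary array (the group size and lane values are arbitrary examples) to show that after log2(group size) steps every lane holds the full sum, which is what VLLM_SHFL_XOR_SYNC accomplishes across GPU lanes.

// Host-side simulation of the XOR-shuffle reduction used in qk_dot_ (not part of the diff).
// Each "lane" starts with a partial dot product; after the butterfly steps every lane
// in the group holds the total, so any lane can use the reduced value.
#include <cstdio>
#include <vector>

int main() {
  const int GROUP = 4;                               // stands in for THREAD_GROUP_SIZE
  std::vector<float> lanes = {1.f, 2.f, 3.f, 4.f};   // hypothetical per-lane partial sums

  for (int mask = GROUP / 2; mask >= 1; mask /= 2) {
    std::vector<float> next(lanes);
    for (int lane = 0; lane < GROUP; ++lane) {
      next[lane] = lanes[lane] + lanes[lane ^ mask];  // read from the XOR partner lane
    }
    lanes = next;
  }

  for (int lane = 0; lane < GROUP; ++lane) {
    std::printf("lane %d -> %.1f\n", lane, lanes[lane]);  // every lane prints 10.0
  }
  return 0;
}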
csrc/attention/dtype_bfloat16.cuh
ADDED
@@ -0,0 +1,451 @@
/*
 * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "attention_generic.cuh"
#include "dtype_float32.cuh"

#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>

typedef __hip_bfloat162 __nv_bfloat162;
typedef __hip_bfloat16 __nv_bfloat16;
#endif

#include <stdint.h>

namespace vllm {

// Define custom BF16 vector data types.
struct bf16_4_t {
  __nv_bfloat162 x;
  __nv_bfloat162 y;
};

struct bf16_8_t {
  __nv_bfloat162 x;
  __nv_bfloat162 y;
  __nv_bfloat162 z;
  __nv_bfloat162 w;
};

// BF16 vector types for Q, K, V.
template<>
struct Vec<__nv_bfloat16, 1> {
  using Type = __nv_bfloat16;
};
template<>
struct Vec<__nv_bfloat16, 2> {
  using Type = __nv_bfloat162;
};
template<>
struct Vec<__nv_bfloat16, 4> {
  using Type = bf16_4_t;
};
template<>
struct Vec<__nv_bfloat16, 8> {
  using Type = bf16_8_t;
};

// FP32 accumulator vector types corresponding to Vec.
template<>
struct FloatVec<__nv_bfloat16> {
  using Type = float;
};
template<>
struct FloatVec<__nv_bfloat162> {
  using Type = float2;
};
template<>
struct FloatVec<bf16_4_t> {
  using Type = Float4_;
};
template<>
struct FloatVec<bf16_8_t> {
  using Type = Float8_;
};

// Utility functions for type conversions.
inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __bfloat1622float2(val);
#endif
}

inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __bfloat162bfloat162(val);
#endif
}

// Vector addition.
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
#ifndef USE_ROCM
  return a + b;
#else
  return __hadd(a, b);
#endif
#endif
}

inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __hadd2(a, b);
#endif
}

inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) {
  bf16_4_t c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  return c;
}

inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) {
  bf16_8_t c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  c.z = add(a.z, b.z);
  c.w = add(a.w, b.w);
  return c;
}

inline __device__ float2 add(__nv_bfloat162 a, float2 fb) {
  float2 fa = bf1622float2(a);
  return add(fa, fb);
}

inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) {
  Float4_ fc;
  fc.x = add(a.x, fb.x);
  fc.y = add(a.y, fb.y);
  return fc;
}

inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
  Float8_ fc;
  fc.x = add(a.x, fb.x);
  fc.y = add(a.y, fb.y);
  fc.z = add(a.z, fb.z);
  fc.w = add(a.w, fb.w);
  return fc;
}

// Vector multiplication.
template<>
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __hmul(a, b);
#endif
}

template<>
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __hmul2(a, b);
#endif
}

template<>
inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
  return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}

template<>
inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
  bf16_4_t c;
  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
  return c;
}

template<>
inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) {
  __nv_bfloat162 s = bf162bf162(a);
  bf16_4_t c;
  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
  return c;
}

template<>
inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
  bf16_8_t c;
  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
  c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
  c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
  return c;
}

template<>
inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) {
  __nv_bfloat162 s = bf162bf162(a);
  bf16_8_t c;
  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
  c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z);
  c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w);
  return c;
}

template<>
inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) {
  float fa = __bfloat162float(a);
  float fb = __bfloat162float(b);
  return fa * fb;
}

template<>
inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
  float2 fa = bf1622float2(a);
  float2 fb = bf1622float2(b);
  return mul<float2, float2, float2>(fa, fb);
}

template<>
inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
  return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}

template<>
inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
  Float4_ fc;
  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
  return fc;
}

template<>
inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) {
  __nv_bfloat162 s = bf162bf162(a);
  Float4_ fc;
  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
  return fc;
}

template<>
inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
  Float8_ fc;
  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
  fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
  fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
  return fc;
}

template<>
inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
  __nv_bfloat162 s = bf162bf162(a);
  Float8_ fc;
  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
  fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.z);
  fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.w);
  return fc;
}

// Vector fused multiply-add.
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __hfma2(a, b, c);
#endif
}

inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __hfma2(bf162bf162(a), b, c);
#endif
}

inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) {
  bf16_4_t d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  return d;
}

inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c) {
  __nv_bfloat162 s = bf162bf162(a);
  bf16_4_t d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  return d;
}

inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) {
  bf16_8_t d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  d.z = fma(a.z, b.z, c.z);
  d.w = fma(a.w, b.w, c.w);
  return d;
}

inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c) {
  __nv_bfloat162 s = bf162bf162(a);
  bf16_8_t d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  d.z = fma(s, b.z, c.z);
  d.w = fma(s, b.w, c.w);
  return d;
}

inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc) {
  return __bfloat162float(a) * __bfloat162float(b) + fc;
}

inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc) {
  float2 fa = bf1622float2(a);
  float2 fb = bf1622float2(b);
  return fma(fa, fb, fc);
}

inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc) {
  return fma(bf162bf162(a), b, fc);
}

inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) {
  Float4_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  return fd;
}

inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc) {
  __nv_bfloat162 s = bf162bf162(a);
  Float4_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  return fd;
}

inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) {
  Float8_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  fd.z = fma(a.z, b.z, fc.z);
  fd.w = fma(a.w, b.w, fc.w);
  return fd;
}

inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) {
  __nv_bfloat162 s = bf162bf162(a);
  Float8_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  fd.z = fma(s, b.z, fc.z);
  fd.w = fma(s, b.w, fc.w);
  return fd;
}

// Vector sum.
template<>
inline __device__ float sum(__nv_bfloat16 v) {
  return __bfloat162float(v);
}

template<>
inline __device__ float sum(__nv_bfloat162 v) {
  float2 vf = bf1622float2(v);
  return vf.x + vf.y;
}

template<>
inline __device__ float sum(bf16_4_t v) {
  return sum(v.x) + sum(v.y);
}

template<>
inline __device__ float sum(bf16_8_t v) {
  return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
}

// From float32 to bfloat16.
inline __device__ void from_float(__nv_bfloat16& dst, float src) {
  dst = __float2bfloat16(src);
}

inline __device__ void from_float(__nv_bfloat162& dst, float2 src) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  dst = __float22bfloat162_rn(src);
#endif
}

inline __device__ void from_float(bf16_4_t& dst, Float4_ src) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  dst.x = __float22bfloat162_rn(src.x);
  dst.y = __float22bfloat162_rn(src.y);
#endif
}

inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  dst.x = __float22bfloat162_rn(src.x);
  dst.y = __float22bfloat162_rn(src.y);
  dst.z = __float22bfloat162_rn(src.z);
  dst.w = __float22bfloat162_rn(src.w);
#endif
}

// From bfloat16 to float32.
inline __device__ float to_float(__nv_bfloat16 u) {
  return __bfloat162float(u);
}

// Zero-out a variable.
inline __device__ void zero(__nv_bfloat16& dst) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  // Same as CUDART_ZERO_BF16 introduced in CUDA 12.2.
  dst = __ushort_as_bfloat16((unsigned short)0x0000U);
#endif
}

} // namespace vllm
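The header above is one instance of the Vec/FloatVec trait pattern: a scalar type plus a packing width selects a storage vector type, and FloatVec selects the FP32 accumulator type used for the dot products. The sketch below is not part of the diff and deliberately swaps the bfloat16 types for hypothetical plain-float structs so it compiles with an ordinary host compiler; it only illustrates how the kernel derives K_vec and its accumulator type from THREAD_GROUP_SIZE.

// Simplified stand-in for the Vec/FloatVec trait pattern (not part of the diff).
// The real headers map __nv_bfloat16 to __nv_bfloat162 / bf16_4_t / bf16_8_t; here
// plain float structs are used so the sketch runs on the host.
#include <cstdio>

struct float2_ { float x, y; };
struct float4_ { float x, y, z, w; };

template<typename T, int N> struct Vec;            // primary template, specialized below
template<> struct Vec<float, 1> { using Type = float;   };
template<> struct Vec<float, 2> { using Type = float2_; };
template<> struct Vec<float, 4> { using Type = float4_; };

template<typename V> struct FloatVec;              // accumulator type for a vector type
template<> struct FloatVec<float>   { using Type = float;   };
template<> struct FloatVec<float2_> { using Type = float2_; };
template<> struct FloatVec<float4_> { using Type = float4_; };

int main() {
  // Mirrors how the kernel builds K_vec: 16 bytes of key data split across a thread group.
  constexpr int THREAD_GROUP_SIZE = 2;
  constexpr int VEC_SIZE = 16 / (THREAD_GROUP_SIZE * sizeof(float));  // == 2
  using K_vec = typename Vec<float, VEC_SIZE>::Type;
  using A_vec = typename FloatVec<K_vec>::Type;

  K_vec k{1.0f, 2.0f};
  A_vec acc{k.x * k.x, k.y * k.y};  // per-lane products, as in qk_dot_
  std::printf("VEC_SIZE = %d, acc = (%.1f, %.1f)\n", VEC_SIZE, acc.x, acc.y);
  return 0;
}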
csrc/attention/dtype_float16.cuh
ADDED
@@ -0,0 +1,502 @@
/*
 * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "attention_generic.cuh"
#include "dtype_float32.cuh"

#ifdef USE_ROCM
  #include <hip/hip_fp16.h>
#endif

#include <stdint.h>

namespace vllm {

// FP16 vector types for Q, K, V.
template<>
struct Vec<uint16_t, 1> {
  using Type = uint16_t;
};
template<>
struct Vec<uint16_t, 2> {
  using Type = uint32_t;
};
template<>
struct Vec<uint16_t, 4> {
  using Type = uint2;
};
template<>
struct Vec<uint16_t, 8> {
  using Type = uint4;
};

// FP32 accumulator vector types corresponding to Vec.
template<>
struct FloatVec<uint16_t> {
  using Type = float;
};
template<>
struct FloatVec<uint32_t> {
  using Type = float2;
};
template<>
struct FloatVec<uint2> {
  using Type = Float4_;
};
template<>
struct FloatVec<uint4> {
  using Type = Float8_;
};

// Utility functions for type conversions.
inline __device__ uint32_t h0_h0(uint16_t a) {
#ifndef USE_ROCM
  uint32_t b;
  asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
  return b;
#else
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
  tmp.u16[0] = a;
  tmp.u16[1] = a;
  return tmp.u32;
#endif
}

inline __device__ float half_to_float(uint16_t h) {
  float f;
#ifndef USE_ROCM
  asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
#else
  asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(f) : "v"(h));
#endif
  return f;
}

inline __device__ float2 half2_to_float2(uint32_t v) {
#ifndef USE_ROCM
  uint16_t lo, hi;
  asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
  return make_float2(half_to_float(lo), half_to_float(hi));
#else
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
  tmp.u32 = v;
  float2 ret;
  ret.x = half_to_float(tmp.u16[0]);
  ret.y = half_to_float(tmp.u16[1]);
  return ret;
#endif
}

inline __device__ uint16_t float_to_half(float f) {
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
#ifndef USE_ROCM
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
#else
  asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(tmp.u32) : "v"(f));
#endif
  return tmp.u16[0];
}

inline __device__ uint32_t float2_to_half2(float2 f) {
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
#ifndef USE_ROCM
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
#else
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
#endif
#else
  tmp.u16[0] = float_to_half(f.x);
  tmp.u16[1] = float_to_half(f.y);
#endif
  return tmp.u32;
}

// Vector addition.
inline __device__ uint16_t add(uint16_t a, uint16_t b) {
  uint16_t c;
#ifndef USE_ROCM
  asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
#else
  asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
  return c;
}

inline __device__ uint32_t add(uint32_t a, uint32_t b) {
  uint32_t c;
#ifndef USE_ROCM
  asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
#else
  asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
  return c;
}

inline __device__ uint2 add(uint2 a, uint2 b) {
  uint2 c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  return c;
}

inline __device__ uint4 add(uint4 a, uint4 b) {
  uint4 c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  c.z = add(a.z, b.z);
  c.w = add(a.w, b.w);
  return c;
}

inline __device__ float2 add(uint32_t a, float2 fb) {
  float2 fa = half2_to_float2(a);
  return add(fa, fb);
}

inline __device__ Float4_ add(uint2 a, Float4_ fb) {
  Float4_ fc;
  fc.x = add(a.x, fb.x);
  fc.y = add(a.y, fb.y);
  return fc;
}

inline __device__ Float8_ add(uint4 a, Float8_ fb) {
  Float8_ fc;
  fc.x = add(a.x, fb.x);
  fc.y = add(a.y, fb.y);
  fc.z = add(a.z, fb.z);
  fc.w = add(a.w, fb.w);
  return fc;
}

// Vector multiplication.
template<>
inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
  uint16_t c;
#ifndef USE_ROCM
  asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
#else
  asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
  return c;
}

template<>
inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
  uint32_t c;
#ifndef USE_ROCM
  asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
#else
  asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
  return c;
}

template<>
inline __device__ uint32_t mul(uint16_t a, uint32_t b) {
  return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b);
}

template<>
inline __device__ uint2 mul(uint2 a, uint2 b) {
  uint2 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
  return c;
}

template<>
inline __device__ uint2 mul(uint16_t a, uint2 b) {
  uint32_t s = h0_h0(a);
  uint2 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
  return c;
}

template<>
inline __device__ uint4 mul(uint4 a, uint4 b) {
  uint4 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
  c.z = mul<uint32_t, uint32_t, uint32_t>(a.z, b.z);
  c.w = mul<uint32_t, uint32_t, uint32_t>(a.w, b.w);
  return c;
}

template<>
inline __device__ uint4 mul(uint16_t a, uint4 b) {
  uint32_t s = h0_h0(a);
  uint4 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
  c.z = mul<uint32_t, uint32_t, uint32_t>(s, b.z);
  c.w = mul<uint32_t, uint32_t, uint32_t>(s, b.w);
  return c;
}

template<>
inline __device__ float mul(uint16_t a, uint16_t b) {
  float fa = half_to_float(a);
  float fb = half_to_float(b);
  return fa * fb;
}

template<>
inline __device__ float2 mul(uint32_t a, uint32_t b) {
  float2 fa = half2_to_float2(a);
  float2 fb = half2_to_float2(b);
  return mul<float2, float2, float2>(fa, fb);
}

template<>
inline __device__ float2 mul(uint16_t a, uint32_t b) {
  return mul<float2, uint32_t, uint32_t>(h0_h0(a), b);
}

template<>
inline __device__ Float4_ mul(uint2 a, uint2 b) {
  Float4_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
  return fc;
}

template<>
inline __device__ Float4_ mul(uint16_t a, uint2 b) {
  uint32_t s = h0_h0(a);
  Float4_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
  return fc;
}

template<>
inline __device__ Float8_ mul(uint4 a, uint4 b) {
  Float8_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
  fc.z = mul<float2, uint32_t, uint32_t>(a.z, b.z);
  fc.w = mul<float2, uint32_t, uint32_t>(a.w, b.w);
  return fc;
}

template<>
inline __device__ Float8_ mul(uint16_t a, uint4 b) {
  uint32_t s = h0_h0(a);
  Float8_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
  fc.z = mul<float2, uint32_t, uint32_t>(s, b.z);
  fc.w = mul<float2, uint32_t, uint32_t>(s, b.w);
  return fc;
}

// Vector fused multiply-add.
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
  uint32_t d;
#ifndef USE_ROCM
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
#else
  asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c));
#endif
  return d;
}

inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) {
  return fma(h0_h0(a), b, c);
}

inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) {
  uint2 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  return d;
}

inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) {
  uint32_t s = h0_h0(a);
  uint2 d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  return d;
}

inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) {
  uint4 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  d.z = fma(a.z, b.z, c.z);
  d.w = fma(a.w, b.w, c.w);
  return d;
}

inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) {
  uint32_t s = h0_h0(a);
  uint4 d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  d.z = fma(s, b.z, c.z);
  d.w = fma(s, b.w, c.w);
  return d;
}

inline __device__ float fma(uint16_t a, uint16_t b, float fc) {
  float fa = half_to_float(a);
  float fb = half_to_float(b);
  return fa * fb + fc;
}

inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) {
  float2 fa = half2_to_float2(a);
  float2 fb = half2_to_float2(b);
  return fma(fa, fb, fc);
}

inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) {
  return fma(h0_h0(a), b, fc);
}

inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) {
  Float4_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  return fd;
}

inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) {
  uint32_t s = h0_h0(a);
  Float4_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  return fd;
}

inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) {
  Float8_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  fd.z = fma(a.z, b.z, fc.z);
  fd.w = fma(a.w, b.w, fc.w);
  return fd;
}

inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) {
  uint32_t s = h0_h0(a);
  Float8_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  fd.z = fma(s, b.z, fc.z);
  fd.w = fma(s, b.w, fc.w);
  return fd;
}

// Vector sum.
template<>
inline __device__ float sum(uint16_t v) {
  return half_to_float(v);
}

template<>
inline __device__ float sum(uint32_t v) {
  float2 tmp = half2_to_float2(v);
  return tmp.x + tmp.y;
}

template<>
inline __device__ float sum(uint2 v) {
  uint32_t c = add(v.x, v.y);
  return sum(c);
}

template<>
inline __device__ float sum(uint4 v) {
  uint32_t c = add(v.x, v.y);
  c = add(c, v.z);
  c = add(c, v.w);
  return sum(c);
}

// From float32 to float16.
inline __device__ void from_float(uint16_t& dst, float src) {
  dst = float_to_half(src);
}

inline __device__ void from_float(uint32_t& dst, float2 src) {
  dst = float2_to_half2(src);
}

inline __device__ void from_float(uint2& dst, Float4_ src) {
  dst.x = float2_to_half2(src.x);
  dst.y = float2_to_half2(src.y);
}

inline __device__ void from_float(uint4& dst, Float8_ src) {
  dst.x = float2_to_half2(src.x);
  dst.y = float2_to_half2(src.y);
  dst.z = float2_to_half2(src.z);
  dst.w = float2_to_half2(src.w);
}

// From float16 to float32.
inline __device__ float to_float(uint16_t u) {
  return half_to_float(u);
}

inline __device__ float2 to_float(uint32_t u) {
  return half2_to_float2(u);
}

inline __device__ Float4_ to_float(uint2 u) {
  Float4_ tmp;
  tmp.x = half2_to_float2(u.x);
  tmp.y = half2_to_float2(u.y);
  return tmp;
}

inline __device__ Float8_ to_float(uint4 u) {
  Float8_ tmp;
  tmp.x = half2_to_float2(u.x);
  tmp.y = half2_to_float2(u.y);
  tmp.z = half2_to_float2(u.z);
  tmp.w = half2_to_float2(u.w);
  return tmp;
}

// Zero-out a variable.
inline __device__ void zero(uint16_t& dst) {
  dst = uint16_t(0);
}

} // namespace vllm
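Editor's note: the point of the Vec/FloatVec pairing above is that packed FP16 operands are multiplied and accumulated in FP32. A minimal sketch of that pattern (helper name ours, not from the diff; assumes this header and attention_generic.cuh are included):

// Illustrative only: "dot product" of 8 packed halves (uint4) carried out in FP32.
__device__ float fp16x8_dot(uint4 q, uint4 k) {
  vllm::Float8_ prod = vllm::mul<vllm::Float8_, uint4, uint4>(q, k);  // 8 FP32 lane products
  return vllm::sum(prod);                                             // reduce the 8 lanes to a scalar
}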
csrc/attention/dtype_float32.cuh
ADDED
@@ -0,0 +1,273 @@
/*
 * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "attention_generic.cuh"

#include <stdint.h>

namespace vllm {

// Define custom FP32 vector data types.
struct Float4_ {
  float2 x;
  float2 y;
};

struct Float8_ {
  float2 x;
  float2 y;
  float2 z;
  float2 w;
};

// FP32 vector types for Q, K, V.
template<>
struct Vec<float, 1> {
  using Type = float;
};
template<>
struct Vec<float, 2> {
  using Type = float2;
};
template<>
struct Vec<float, 4> {
  using Type = float4;
};

// FP32 accumulator vector types corresponding to Vec.
template<>
struct FloatVec<float> {
  using Type = float;
};
template<>
struct FloatVec<float2> {
  using Type = float2;
};
template<>
struct FloatVec<float4> {
  using Type = float4;
};

// Vector addition.
inline __device__ float add(float a, float b) {
  return a + b;
}

inline __device__ float2 add(float2 a, float2 b) {
  float2 c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  return c;
}

inline __device__ float4 add(float4 a, float4 b) {
  float4 c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  c.z = add(a.z, b.z);
  c.w = add(a.w, b.w);
  return c;
}

// Vector multiplication.
template<>
inline __device__ float mul<float, float>(float a, float b) {
  return a * b;
}

template<>
inline __device__ float2 mul(float2 a, float2 b) {
  float2 c;
  c.x = a.x * b.x;
  c.y = a.y * b.y;
  return c;
}

template<>
inline __device__ float2 mul(float a, float2 b) {
  float2 c;
  c.x = a * b.x;
  c.y = a * b.y;
  return c;
}

template<>
inline __device__ float4 mul(float4 a, float4 b) {
  float4 c;
  c.x = a.x * b.x;
  c.y = a.y * b.y;
  c.z = a.z * b.z;
  c.w = a.w * b.w;
  return c;
}

template<>
inline __device__ float4 mul(float a, float4 b) {
  float4 c;
  c.x = a * b.x;
  c.y = a * b.y;
  c.z = a * b.z;
  c.w = a * b.w;
  return c;
}

// Vector fused multiply-add.
inline __device__ float fma(float a, float b, float c) {
  return a * b + c;
}

inline __device__ float2 fma(float2 a, float2 b, float2 c) {
  float2 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  return d;
}

inline __device__ float2 fma(float a, float2 b, float2 c) {
  float2 d;
  d.x = fma(a, b.x, c.x);
  d.y = fma(a, b.y, c.y);
  return d;
}

inline __device__ float4 fma(float4 a, float4 b, float4 c) {
  float4 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  d.z = fma(a.z, b.z, c.z);
  d.w = fma(a.w, b.w, c.w);
  return d;
}

inline __device__ float4 fma(float a, float4 b, float4 c) {
  float4 d;
  d.x = fma(a, b.x, c.x);
  d.y = fma(a, b.y, c.y);
  d.z = fma(a, b.z, c.z);
  d.w = fma(a, b.w, c.w);
  return d;
}

inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) {
  Float4_ d;
  d.x = fma(a, b.x, c.x);
  d.y = fma(a, b.y, c.y);
  return d;
}

inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
  Float8_ d;
  d.x = fma(a, b.x, c.x);
  d.y = fma(a, b.y, c.y);
  d.z = fma(a, b.z, c.z);
  d.w = fma(a, b.w, c.w);
  return d;
}

// Vector sum.
template<>
inline __device__ float sum(float v) {
  return v;
}

template<>
inline __device__ float sum(float2 v) {
  return v.x + v.y;
}

template<>
inline __device__ float sum(float4 v) {
  return v.x + v.y + v.z + v.w;
}

template<>
inline __device__ float sum(Float4_ v) {
  return v.x.x + v.x.y + v.y.x + v.y.y;
}

template<>
inline __device__ float sum(Float8_ v) {
  return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
}

// Vector dot product.
inline __device__ float dot(float a, float b) {
  return a * b;
}

inline __device__ float dot(float2 a, float2 b) {
  float2 c = mul<float2, float2, float2>(a, b);
  return c.x + c.y;
}

inline __device__ float dot(Float4_ a, Float4_ b) {
  float2 acc = mul<float2, float2, float2>(a.x, b.x);
  acc = fma(a.y, b.y, acc);
  return acc.x + acc.y;
}

inline __device__ float dot(Float8_ a, Float8_ b) {
  float2 acc = mul<float2, float2, float2>(a.x, b.x);
  acc = fma(a.y, b.y, acc);
  acc = fma(a.z, b.z, acc);
  acc = fma(a.w, b.w, acc);
  return acc.x + acc.y;
}

// From float to float.
inline __device__ void from_float(float& dst, float src) {
  dst = src;
}

inline __device__ void from_float(float2& dst, float2 src) {
  dst = src;
}

inline __device__ void from_float(float4& dst, float4 src) {
  dst = src;
}

// From float to float.
inline __device__ float to_float(float u) {
  return u;
}

inline __device__ float2 to_float(float2 u) {
  return u;
}

inline __device__ float4 to_float(float4 u) {
  return u;
}

inline __device__ Float4_ to_float(Float4_ u) {
  return u;
}

inline __device__ Float8_ to_float(Float8_ u) {
  return u;
}

// Zero-out a variable.
inline __device__ void zero(float& dst) {
  dst = 0.f;
}

} // namespace vllm
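Editor's note: Float4_ and Float8_ are plain aggregates of float2, so the scalar-times-vector fma above is all that is needed for softmax-weighted accumulation. A one-line sketch (function name ours, for illustration only):

// d = a * x + y, lane-wise over the four floats packed in a Float4_.
__device__ vllm::Float4_ axpy4(float a, vllm::Float4_ x, vllm::Float4_ y) {
  return vllm::fma(a, x, y);
}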
csrc/attention/dtype_fp8_e5m2.cuh
ADDED
@@ -0,0 +1,35 @@
#pragma once

#include "attention_generic.cuh"

#include <stdint.h>
#ifdef ENABLE_FP8_E5M2
#include <cuda_fp8.h>
#endif

namespace vllm {
#ifdef ENABLE_FP8_E5M2
// fp8 vector types for quantization of kv cache

template<>
struct Vec<uint8_t, 1> {
  using Type = uint8_t;
};

template<>
struct Vec<uint8_t, 2> {
  using Type = uint16_t;
};

template<>
struct Vec<uint8_t, 4> {
  using Type = uint32_t;
};

template<>
struct Vec<uint8_t, 8> {
  using Type = uint2;
};
#endif // ENABLE_FP8_E5M2

} // namespace vllm
csrc/cache.h
ADDED
@@ -0,0 +1,36 @@
#pragma once

#include <torch/extension.h>

#include <map>
#include <vector>

void swap_blocks(
  torch::Tensor& src,
  torch::Tensor& dst,
  const std::map<int64_t, int64_t>& block_mapping);

void copy_blocks(
  std::vector<torch::Tensor>& key_caches,
  std::vector<torch::Tensor>& value_caches,
  const std::map<int64_t, std::vector<int64_t>>& block_mapping);

void reshape_and_cache(
  torch::Tensor& key,
  torch::Tensor& value,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  torch::Tensor& slot_mapping,
  const std::string& kv_cache_dtype);

void gather_cached_kv(
  torch::Tensor& key,
  torch::Tensor& value,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  torch::Tensor& slot_mapping);

// Just for unittest
void convert_fp8_e5m2(
  torch::Tensor& src_cache,
  torch::Tensor& dst_cache);
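Editor's note: these entry points are normally bound to Python, but a host-side C++ call is the clearest way to show the block-mapping convention. A hedged sketch (function and tensor names ours, not from the commit):

// Hypothetical call site: copy cache blocks 0 and 1 of a GPU cache into
// slots 3 and 7 of a CPU cache with the same per-block shape.
void demo_swap(torch::Tensor& gpu_cache, torch::Tensor& cpu_cache) {
  std::map<int64_t, int64_t> block_mapping = {{0, 3}, {1, 7}};  // src block -> dst block
  swap_blocks(gpu_cache, cpu_cache, block_mapping);
}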
csrc/cache_kernels.cu
ADDED
@@ -0,0 +1,474 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"

#include <algorithm>
#include <cassert>
#include <map>
#include <vector>

void swap_blocks(
  torch::Tensor& src,
  torch::Tensor& dst,
  const std::map<int64_t, int64_t>& block_mapping) {
  torch::Device src_device = src.device();
  torch::Device dst_device = dst.device();
  cudaMemcpyKind memcpy_type;
  if (src_device.is_cuda() && dst_device.is_cuda()) {
    TORCH_CHECK(
      src_device.index() == dst_device.index(),
      "src and dst must be on the same GPU");
    memcpy_type = cudaMemcpyDeviceToDevice;
  } else if (src_device.is_cuda() && dst_device.is_cpu()) {
    memcpy_type = cudaMemcpyDeviceToHost;
  } else if (src_device.is_cpu() && dst_device.is_cuda()) {
    memcpy_type = cudaMemcpyHostToDevice;
  } else {
    TORCH_CHECK(false, "Invalid device combination");
  }

  char *src_ptr = static_cast<char*>(src.data_ptr());
  char *dst_ptr = static_cast<char*>(dst.data_ptr());

  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
  const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // NOTE(woosuk): This can be slow if the number of blocks is large.
  for (const auto& pair : block_mapping) {
    int64_t src_block_number = pair.first;
    int64_t dst_block_number = pair.second;
    int64_t src_offset = src_block_number * block_size_in_bytes;
    int64_t dst_offset = dst_block_number * block_size_in_bytes;
    cudaMemcpyAsync(
      dst_ptr + dst_offset,
      src_ptr + src_offset,
      block_size_in_bytes,
      memcpy_type,
      stream);
  }
}

namespace vllm {

// Grid: (num_layers, num_pairs)
template<typename scalar_t>
__global__ void copy_blocks_kernel(
  int64_t* key_cache_ptrs,
  int64_t* value_cache_ptrs,
  const int64_t* __restrict__ block_mapping,
  const int numel_per_block) {
  const int layer_idx = blockIdx.x;
  const int pair_idx = blockIdx.y;

  scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
  scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
  int64_t src_block_number = block_mapping[2 * pair_idx];
  int64_t dst_block_number = block_mapping[2 * pair_idx + 1];

  const int64_t src_block_offset = src_block_number * numel_per_block;
  const int64_t dst_block_offset = dst_block_number * numel_per_block;
  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
    int64_t src_offset = src_block_offset + i;
    int64_t dst_offset = dst_block_offset + i;
    key_cache[dst_offset] = key_cache[src_offset];
  }
  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
    int64_t src_offset = src_block_offset + i;
    int64_t dst_offset = dst_block_offset + i;
    value_cache[dst_offset] = value_cache[src_offset];
  }
}

} // namespace vllm

void copy_blocks(
  std::vector<torch::Tensor>& key_caches,
  std::vector<torch::Tensor>& value_caches,
  const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
  int num_layers = key_caches.size();
  TORCH_CHECK(num_layers == value_caches.size());
  if (num_layers == 0) {
    return;
  }
  torch::Device cache_device = key_caches[0].device();
  TORCH_CHECK(cache_device.is_cuda());

  // Create data structures for the kernel.
  // Create an array of pointers to the key and value caches.
  int64_t key_cache_ptrs[num_layers];
  int64_t value_cache_ptrs[num_layers];
  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
    key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
    value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
  }
  // Create block mapping array.
  std::vector<int64_t> block_mapping_vec;
  for (const auto& pair : block_mapping) {
    int64_t src_block_number = pair.first;
    for (int64_t dst_block_number : pair.second) {
      block_mapping_vec.push_back(src_block_number);
      block_mapping_vec.push_back(dst_block_number);
    }
  }
  int64_t* block_mapping_array = block_mapping_vec.data();
  int num_pairs = block_mapping_vec.size() / 2;

  // Move the data structures to the GPU.
  // NOTE: This synchronizes the CPU and GPU.
  torch::Tensor key_cache_ptrs_tensor = torch::from_blob(
    key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
  torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
    value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
  torch::Tensor block_mapping_tensor = torch::from_blob(
    block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device);

  // Launch the kernel.
  const int numel_per_block = key_caches[0][0].numel();
  dim3 grid(num_layers, num_pairs);
  dim3 block(std::min(1024, numel_per_block));
  const at::cuda::OptionalCUDAGuard device_guard(cache_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
    key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
      vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
        key_cache_ptrs_tensor.data_ptr<int64_t>(),
        value_cache_ptrs_tensor.data_ptr<int64_t>(),
        block_mapping_tensor.data_ptr<int64_t>(),
        numel_per_block);
    }));
}

namespace vllm {

template<typename scalar_t, typename cache_t, bool is_fp8_e5m2_kv_cache>
__global__ void reshape_and_cache_kernel(
  const scalar_t* __restrict__ key,          // [num_tokens, num_heads, head_size]
  const scalar_t* __restrict__ value,        // [num_tokens, num_heads, head_size]
  cache_t* __restrict__ key_cache,           // [num_blocks, num_heads, head_size/x, block_size, x]
  cache_t* __restrict__ value_cache,         // [num_blocks, num_heads, head_size, block_size]
  const int64_t* __restrict__ slot_mapping,  // [num_tokens]
  const int key_stride,
  const int value_stride,
  const int num_heads,
  const int head_size,
  const int block_size,
  const int x) {
  const int64_t token_idx = blockIdx.x;
  const int64_t slot_idx = slot_mapping[token_idx];
  if (slot_idx < 0) {
    // Padding token that should be ignored.
    return;
  }

  const int64_t block_idx = slot_idx / block_size;
  const int64_t block_offset = slot_idx % block_size;

  const int n = num_heads * head_size;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    const int64_t src_key_idx = token_idx * key_stride + i;
    const int64_t src_value_idx = token_idx * value_stride + i;

    const int head_idx = i / head_size;
    const int head_offset = i % head_size;
    const int x_idx = head_offset / x;
    const int x_offset = head_offset % x;

    const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
                                + head_idx * (head_size / x) * block_size * x
                                + x_idx * block_size * x
                                + block_offset * x
                                + x_offset;
    const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size
                                  + head_idx * head_size * block_size
                                  + head_offset * block_size
                                  + block_offset;
    scalar_t tgt_key = key[src_key_idx];
    scalar_t tgt_value = value[src_value_idx];
    if constexpr (is_fp8_e5m2_kv_cache) {
#ifdef ENABLE_FP8_E5M2
      key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
      value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
#else
      assert(false);
#endif
    } else {
      key_cache[tgt_key_idx] = tgt_key;
      value_cache[tgt_value_idx] = tgt_value;
    }
  }
}

} // namespace vllm

#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
  vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE><<<grid, block, 0, stream>>>( \
    reinterpret_cast<KV_T*>(key.data_ptr()), \
    reinterpret_cast<KV_T*>(value.data_ptr()), \
    reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
    reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
    slot_mapping.data_ptr<int64_t>(), \
    key_stride, \
    value_stride, \
    num_heads, \
    head_size, \
    block_size, \
    x);

void reshape_and_cache(
  torch::Tensor& key,           // [num_tokens, num_heads, head_size]
  torch::Tensor& value,         // [num_tokens, num_heads, head_size]
  torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
  torch::Tensor& slot_mapping,  // [num_tokens]
  const std::string& kv_cache_dtype)
{
  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  if (kv_cache_dtype == "auto") {
    if (key.dtype() == at::ScalarType::Float) {
      CALL_RESHAPE_AND_CACHE(float, float, false);
    } else if (key.dtype() == at::ScalarType::Half) {
      CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false);
    } else if (key.dtype() == at::ScalarType::BFloat16) {
      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
    }
  } else if (kv_cache_dtype == "fp8_e5m2") {
    if (key.dtype() == at::ScalarType::Float) {
      CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
    } else if (key.dtype() == at::ScalarType::Half) {
      CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true);
    } else if (key.dtype() == at::ScalarType::BFloat16) {
      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true);
    }
  } else {
    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
  }
}

namespace vllm {

// Grid: (num_blocks, block_size).
template<typename scalar_t>
__global__ void gather_cached_kv_kernel(
  scalar_t* __restrict__ key,                // [num_tokens, [stride], num_heads, head_size]
  scalar_t* __restrict__ value,              // [num_tokens, [stride], num_heads, head_size]
  const scalar_t* __restrict__ key_cache,    // [num_blocks, num_heads, head_size/x, block_size, x]
  const scalar_t* __restrict__ value_cache,  // [num_blocks, num_heads, head_size, block_size]
  const int* __restrict__ slot_mapping,      // [num_tokens]
  const int key_stride,
  const int value_stride,
  const int num_heads,
  const int head_size,
  const int block_size,
  const int x) {
  const int token_idx = blockIdx.x;
  const int slot_idx = slot_mapping[token_idx];
  const int block_idx = slot_idx / block_size;
  const int block_offset = slot_idx % block_size;

  const int num_tokens = num_heads * head_size;
  for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) {
    const int tgt_key_idx = token_idx * key_stride + i;
    const int tgt_value_idx = token_idx * value_stride + i;

    const int head_idx = i / head_size;
    const int head_offset = i % head_size;
    const int x_idx = head_offset / x;  // the offset of the [head_size/x] dimension
    const int x_offset = head_offset % x;

    const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
                            + head_idx * (head_size / x) * block_size * x
                            + x_idx * block_size * x
                            + block_offset * x
                            + x_offset;
    const int src_value_idx = block_idx * num_heads * head_size * block_size
                              + head_idx * head_size * block_size
                              + head_offset * block_size
                              + block_offset;

    key[tgt_key_idx] = VLLM_LDG(&key_cache[src_key_idx]);
    value[tgt_value_idx] = VLLM_LDG(&value_cache[src_value_idx]);
  }
}

template <typename scalar_t>
__global__ void gather_cached_kv_kernel_optimized(
  scalar_t *__restrict__ key,                // [num_tokens, [stride], num_heads, head_size]
  scalar_t *__restrict__ value,              // [num_tokens, [stride], num_heads, head_size]
  const scalar_t *__restrict__ key_cache,    // [num_blocks, num_heads, head_size/x, block_size, x]
  const scalar_t *__restrict__ value_cache,  // [num_blocks, num_heads, head_size, block_size]
  const int *__restrict__ slot_mapping,      // [num_tokens]
  const int key_stride,
  const int value_stride,
  const int num_heads,
  const int head_size,
  const int block_size,
  const int x)
{
  const int token_idx = blockIdx.x;
  const int slot_idx = slot_mapping[token_idx];
  const int block_idx = slot_idx / block_size;
  const int block_offset = slot_idx % block_size;

  const int dim = num_heads * head_size;
  assert(dim % 4 == 0);  // this is true for known use cases
  const int unroll_factor = 4;
  const int unrolled_dim = dim / unroll_factor;

  for (int i = threadIdx.x; i < unrolled_dim; i += blockDim.x)
  {
    int tgt_key_indices[unroll_factor];
    int tgt_value_indices[unroll_factor];
    int src_key_indices[unroll_factor];
    int src_value_indices[unroll_factor];
    scalar_t keys_to_store[unroll_factor];
    scalar_t values_to_store[unroll_factor];

    #pragma unroll
    for (int j = 0; j < unroll_factor; ++j)
    {
      int index = i + j * unrolled_dim;

      const int tgt_key_idx = token_idx * key_stride + index;
      const int tgt_value_idx = token_idx * value_stride + index;

      const int head_idx = index / head_size;
      const int head_offset = index % head_size;
      const int x_idx = head_offset / x;
      const int x_offset = head_offset % x;

      const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
                              + head_idx * (head_size / x) * block_size * x
                              + x_idx * block_size * x
                              + block_offset * x
                              + x_offset;
      const int src_value_idx = block_idx * num_heads * head_size * block_size
                                + head_idx * head_size * block_size
                                + head_offset * block_size
                                + block_offset;

      tgt_key_indices[j] = tgt_key_idx;
      tgt_value_indices[j] = tgt_value_idx;
      src_key_indices[j] = src_key_idx;
      src_value_indices[j] = src_value_idx;

      keys_to_store[j] = VLLM_LDG(&key_cache[src_key_idx]);
      values_to_store[j] = VLLM_LDG(&value_cache[src_value_idx]);
    }

    #pragma unroll
    for (int j = 0; j < unroll_factor; ++j)
    {
      key[tgt_key_indices[j]] = keys_to_store[j];
      value[tgt_value_indices[j]] = values_to_store[j];
    }
  }
}

} // namespace vllm

void gather_cached_kv(
  torch::Tensor& key,           // [out] [num_tokens, num_heads, head_size]
  torch::Tensor& value,         // [out] [num_tokens, num_heads, head_size]
  torch::Tensor& key_cache,     // [in]  [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,   // [in]  [num_blocks, num_heads, head_size, block_size]
  torch::Tensor& slot_mapping)  // [in]  [num_tokens]
{
  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
    key.scalar_type(),
    "gather_cached_kv_kernel_optimized",
    [&] {
      vllm::gather_cached_kv_kernel_optimized<scalar_t><<<grid, block, 0, stream>>>(
        key.data_ptr<scalar_t>(),
        value.data_ptr<scalar_t>(),
        key_cache.data_ptr<scalar_t>(),
        value_cache.data_ptr<scalar_t>(),
        slot_mapping.data_ptr<int>(),
        key_stride,
        value_stride,
        num_heads,
        head_size,
        block_size,
        x);
    });
}

namespace vllm {

template<typename Tout, typename Tin>
__global__ void convert_fp8_e5m2_kernel(
  const Tin* __restrict__ src_cache,
  Tout* __restrict__ dst_cache,
  const int64_t block_stride) {
  const int64_t block_idx = blockIdx.x;
  for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
    int64_t idx = block_idx * block_stride + i;
#ifdef ENABLE_FP8_E5M2
    dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]);
#else
    assert(false);
#endif
  }
}

} // namespace vllm

#define CALL_CONVERT_FP8_E5M2(Tout, Tin) \
  vllm::convert_fp8_e5m2_kernel<Tout, Tin><<<grid, block, 0, stream>>>( \
    reinterpret_cast<Tin*>(src_cache.data_ptr()), \
    reinterpret_cast<Tout*>(dst_cache.data_ptr()), \
    block_stride);

void convert_fp8_e5m2(
  torch::Tensor& src_cache,
  torch::Tensor& dst_cache)
{
  int64_t num_blocks = src_cache.size(0);
  int64_t block_stride = src_cache.stride(0);

  dim3 grid(num_blocks);
  dim3 block(std::min(block_stride, int64_t(512)));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if (src_cache.dtype() == at::ScalarType::Float) {
    CALL_CONVERT_FP8_E5M2(uint8_t, float);
  } else if (src_cache.dtype() == at::ScalarType::Half) {
    CALL_CONVERT_FP8_E5M2(uint8_t, uint16_t);
  } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
    CALL_CONVERT_FP8_E5M2(uint8_t, __nv_bfloat16);
  } else if (dst_cache.dtype() == at::ScalarType::Float) {
    CALL_CONVERT_FP8_E5M2(float, uint8_t);
  } else if (dst_cache.dtype() == at::ScalarType::Half) {
    CALL_CONVERT_FP8_E5M2(uint16_t, uint8_t);
  } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
    CALL_CONVERT_FP8_E5M2(__nv_bfloat16, uint8_t);
  }
}
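Editor's note: the only subtle part of reshape_and_cache_kernel is the key-cache addressing. Restated as a standalone helper for clarity (helper name ours; it mirrors the arithmetic in the kernel above):

// key_cache layout: [num_blocks, num_heads, head_size/x, block_size, x]
__host__ __device__ int64_t key_cache_index(
    int64_t block_idx, int head_idx, int head_offset, int64_t block_offset,
    int num_heads, int head_size, int block_size, int x) {
  const int x_idx = head_offset / x;     // which x-wide chunk of the head dimension
  const int x_offset = head_offset % x;  // position inside that chunk
  return block_idx * num_heads * (head_size / x) * block_size * x
       + head_idx * (head_size / x) * block_size * x
       + x_idx * block_size * x
       + block_offset * x
       + x_offset;
}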
csrc/cuda_compat.h
ADDED
@@ -0,0 +1,28 @@
#pragma once

#ifndef USE_ROCM
  #define VLLM_LDG(arg) __ldg(arg)
#else
  #define VLLM_LDG(arg) *(arg)
#endif

#ifndef USE_ROCM
  #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask)
#else
  #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
#endif

#ifndef USE_ROCM
  #define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane)
#else
  #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
#endif

#ifndef USE_ROCM
  #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
    cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
#else
  #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
    hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
#endif
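Editor's note: these shuffle macros exist so the same warp-reduction code compiles on both CUDA and ROCm. A minimal sketch of that pattern (function name ours; it assumes a 32-lane warp, whereas ROCm wavefronts are 64 wide and the real kernels parameterize the width):

// Butterfly reduction: after the loop every lane holds the sum over the 32 lanes.
__device__ float warp_sum(float v) {
  for (int mask = 16; mask > 0; mask >>= 1) {
    v += VLLM_SHFL_XOR_SYNC(v, mask);
  }
  return v;
}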
csrc/cuda_utils.h
ADDED
@@ -0,0 +1,10 @@
#pragma once

#include <torch/extension.h>

int get_device_attribute(
  int attribute,
  int device_id);

int get_max_shared_memory_per_block_device_attribute(
  int device_id);
csrc/cuda_utils_kernels.cu
ADDED
@@ -0,0 +1,35 @@
#ifdef USE_ROCM
  #include <hip/hip_runtime.h>
  #include <hip/hip_runtime_api.h>
#endif
int get_device_attribute(
  int attribute,
  int device_id)
{
  int device, value;
  if (device_id < 0) {
    cudaGetDevice(&device);
  }
  else {
    device = device_id;
  }
  cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device);
  return value;
}

int get_max_shared_memory_per_block_device_attribute(
  int device_id)
{
  int attribute;
  // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
  // cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74

#ifdef USE_ROCM
  attribute = hipDeviceAttributeMaxSharedMemoryPerBlock;
#else
  attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin;
#endif

  return get_device_attribute(attribute, device_id);
}
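Editor's note: a hedged sketch of how the helper above pairs with the cuda_compat.h macro when a kernel needs more dynamic shared memory than the default cap allows (function name ours; assumes both headers are included):

// Raise a kernel's dynamic shared-memory limit to the device's opt-in maximum.
void demo_set_smem(const void* kernel_func, int device_id) {
  int max_smem = get_max_shared_memory_per_block_device_attribute(device_id);
  VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(kernel_func, max_smem);
}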
csrc/custom_all_reduce.cu
ADDED
@@ -0,0 +1,148 @@
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>

#include "custom_all_reduce.cuh"

// fake pointer type
using fptr_t = uint64_t;
static_assert(sizeof(void *) == sizeof(fptr_t));

fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
                      const std::vector<std::string> &handles,
                      const std::vector<int64_t> &offsets, int rank,
                      bool full_nvlink) {
  int world_size = offsets.size();
  if (world_size > 8)
    throw std::invalid_argument("world size > 8 is not supported");
  if (world_size % 2 != 0)
    throw std::invalid_argument("Odd num gpus is not supported for now");
  if (world_size != handles.size())
    throw std::invalid_argument(
        "handles length should equal to offsets length");
  if (rank < 0 || rank >= world_size)
    throw std::invalid_argument("invalid rank passed in");

  cudaIpcMemHandle_t ipc_handles[8];
  for (int i = 0; i < world_size; i++) {
    std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t));
  }
  return (fptr_t) new vllm::CustomAllreduce(
      reinterpret_cast<vllm::Metadata *>(meta.data_ptr()), rank_data.data_ptr(),
      rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
}

/**
 * Make sure tensor t's data lies completely within ((char)t.data_ptr()) +
 * t.numel() * t.element_size(). This is slightly weaker than t.is_contiguous()
 * because it allows transpose of contiguous slice (i.e. slicing the first
 * dimension). Currently, we require this because stride information is not
 * passed into the kernels and we treat input tensors as flat.
 *
 * Examples
 * A = torch.zeros(3, 3, 3)
 * 1. A: OK
 * 2. A[1:]: OK
 * 3. A.permute(2, 0, 1): OK
 * 4. A[1:].permute(2, 0, 1): OK
 * 5. A[None].expand(2, -1, -1, -1): Not OK
 * 6. A[:, 1:, 1:]: Not OK
 */
bool _is_weak_contiguous(torch::Tensor &t) {
  return t.is_contiguous() ||
         (t.storage().nbytes() - t.storage_offset() * t.element_size() ==
          t.numel() * t.element_size());
}

bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
                      bool full_nvlink) {
  auto inp_size = inp.numel() * inp.element_size();
  // custom allreduce requires input byte size to be multiples of 16
  if (inp_size % 16 != 0) return false;
  if (!_is_weak_contiguous(inp)) return false;
  if (world_size == 2 || full_nvlink) return inp_size <= max_size;
  // 4 PCIE GPUs use 2 stage allreduce, and is only faster than NCCL when size
  // <= 512k
  return world_size <= 4 && inp_size <= 512 * 1024;
}

void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
                 cudaStream_t stream) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
  TORCH_CHECK(_is_weak_contiguous(out));
  switch (out.scalar_type()) {
    case at::ScalarType::Float: {
      fa->allreduce<float>(stream, reinterpret_cast<float *>(inp.data_ptr()),
                           reinterpret_cast<float *>(out.data_ptr()),
                           out.numel());
      break;
    }
    case at::ScalarType::Half: {
      fa->allreduce<half>(stream, reinterpret_cast<half *>(inp.data_ptr()),
                          reinterpret_cast<half *>(out.data_ptr()),
                          out.numel());
      break;
    }
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
    case at::ScalarType::BFloat16: {
      fa->allreduce<nv_bfloat16>(
          stream, reinterpret_cast<nv_bfloat16 *>(inp.data_ptr()),
          reinterpret_cast<nv_bfloat16 *>(out.data_ptr()), out.numel());
      break;
    }
#endif
    default:
      throw std::runtime_error(
          "custom allreduce only supports float32, float16 and bfloat16");
  }
}

void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = c10::cuda::getCurrentCUDAStream().stream();
  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
  TORCH_CHECK_EQ(inp.numel(), out.numel());
  _all_reduce(_fa, inp, out, stream);
}

void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer,
                      torch::Tensor &out) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = c10::cuda::getCurrentCUDAStream().stream();

  auto input_size = inp.numel() * inp.element_size();
  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
  TORCH_CHECK_EQ(inp.numel(), out.numel());
  TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(),
              "registered buffer is too small to contain the input");
  AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(),
                                input_size, cudaMemcpyDeviceToDevice, stream));
  _all_reduce(_fa, reg_buffer, out, stream);
}

void dispose(fptr_t _fa) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
  delete fa;
}

int meta_size() { return sizeof(vllm::Metadata); }

void register_buffer(fptr_t _fa, torch::Tensor &t,
                     const std::vector<std::string> &handles,
                     const std::vector<int64_t> &offsets) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
  fa->register_buffer(handles, offsets, t.data_ptr());
}

std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
    fptr_t _fa) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
  return fa->get_graph_buffer_ipc_meta();
}

void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
                            const std::vector<std::vector<int64_t>> &offsets) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
  fa->register_graph_buffers(handles, offsets);
}
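Note: for orientation, a rough sketch of the call sequence a Python wrapper is expected to drive through these bindings. The handles/offsets arguments are placeholders (in practice the IPC handles for `meta` and for the input buffer are exchanged separately across ranks), error handling is omitted, and this is not code from the diff.

// Hypothetical driver, assuming the declarations above are visible.
void custom_ar_lifecycle(torch::Tensor &meta, torch::Tensor &rank_data,
                         const std::vector<std::string> &handles,
                         const std::vector<int64_t> &offsets, int rank,
                         torch::Tensor &inp, torch::Tensor &out) {
  // 1. Build the allreduce context; the returned fptr_t is an opaque handle.
  fptr_t fa = init_custom_ar(meta, rank_data, handles, offsets, rank,
                             /*full_nvlink=*/true);
  // 2. Register the input buffer once so peer ranks can map it over IPC.
  //    (The real handles/offsets here would describe `inp`, not `meta`.)
  register_buffer(fa, inp, handles, offsets);
  // 3. Fast path: the registered buffer is reduced directly into `out`.
  if (should_custom_ar(inp, /*max_size=*/8 * 1024 * 1024,
                       (int)offsets.size(), /*full_nvlink=*/true)) {
    all_reduce_reg(fa, inp, out);
  }
  // 4. Close IPC handles and free the context.
  dispose(fa);
}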
csrc/custom_all_reduce.cuh
ADDED
@@ -0,0 +1,562 @@
#pragma once

#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <iostream>
#include <limits>
#include <map>
#include <unordered_map>
#include <vector>

#define CUDACHECK(cmd)                                              \
  do {                                                              \
    cudaError_t e = cmd;                                            \
    if (e != cudaSuccess) {                                         \
      printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \
             cudaGetErrorString(e));                                \
      exit(EXIT_FAILURE);                                           \
    }                                                               \
  } while (0)

namespace vllm {

struct Signal {
  alignas(64) union {
    uint64_t flag;
    unsigned char data[8];
  } start;
  alignas(64) union {
    uint64_t flag;
    unsigned char data[8];
  } end;
};

struct Metadata {
  alignas(128) Signal sg;
  alignas(128) int counter;
};
static_assert(offsetof(Metadata, counter) == 128);
static_assert(sizeof(Metadata) == 256);

struct __align__(16) RankData { const void *__restrict__ ptrs[8]; };

struct RankSignals {
  volatile Signal *signals[8];
};

// like std::array, but aligned
template <typename T, int sz>
struct __align__(alignof(T) * sz) array_t {
  T data[sz];
  using type = T;
  static constexpr int size = sz;
};

// use packed type to maximize memory efficiency
// goal: generate ld.128 and st.128 instructions
template <typename T>
struct packed_t {
  // the (P)acked type for load/store
  using P = array_t<T, 16 / sizeof(T)>;
  // the (A)ccumulator type for reduction
  using A = array_t<float, 16 / sizeof(T)>;
};

#define DINLINE __device__ __forceinline__

// scalar cast functions
DINLINE float upcast_s(half val) { return __half2float(val); }

template <typename T>
DINLINE T downcast_s(float val);
template <>
DINLINE half downcast_s(float val) {
  return __float2half(val);
}

// scalar add functions
// for some reason when compiling with Pytorch, the + operator for half and
// bfloat is disabled so we call the intrinsics directly
DINLINE half &assign_add(half &a, half b) {
  a = __hadd(a, b);
  return a;
}
DINLINE float &assign_add(float &a, float b) { return a += b; }

#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); }
template <>
DINLINE nv_bfloat16 downcast_s(float val) {
  return __float2bfloat16(val);
}
DINLINE nv_bfloat16 &assign_add(nv_bfloat16 &a, nv_bfloat16 b) {
  a = __hadd(a, b);
  return a;
}
#endif

template <typename T, int N>
DINLINE array_t<T, N> &packed_assign_add(array_t<T, N> &a, array_t<T, N> b) {
#pragma unroll
  for (int i = 0; i < N; i++) {
    assign_add(a.data[i], b.data[i]);
  }
  return a;
}

template <typename T, int N>
DINLINE array_t<float, N> upcast(array_t<T, N> val) {
  if constexpr (std::is_same<T, float>::value) {
    return val;
  } else {
    array_t<float, N> out;
#pragma unroll
    for (int i = 0; i < N; i++) {
      out.data[i] = upcast_s(val.data[i]);
    }
    return out;
  }
}

template <typename O>
DINLINE O downcast(array_t<float, O::size> val) {
  if constexpr (std::is_same<typename O::type, float>::value) {
    return val;
  } else {
    O out;
#pragma unroll
    for (int i = 0; i < O::size; i++) {
      out.data[i] = downcast_s<typename O::type>(val.data[i]);
    }
    return out;
  }
}

// compute flag at compile time
__host__ __device__ constexpr uint64_t compute_flag(int ngpus) {
  auto m = std::numeric_limits<uint64_t>::max();
  return m >> ((8 - ngpus) * 8);
}

template <int ngpus>
DINLINE void start_sync(const RankSignals &sg, volatile Metadata *meta,
                        int rank) {
  constexpr auto FLAG = compute_flag(ngpus);
  if (blockIdx.x == 0) {
    if (threadIdx.x < ngpus)
      // simultaneously write to the corresponding byte to all other ranks.
      // Latency = 1 p2p write
      sg.signals[threadIdx.x]->start.data[rank] = 255;
    else if (threadIdx.x == 32)
      // reset
      meta->sg.end.flag = 0;
  }
  if (threadIdx.x == 0) {
    while (meta->sg.start.flag != FLAG)
      ;
  }
  __syncthreads();
}

template <int ngpus, bool final_sync = false>
DINLINE void end_sync(const RankSignals &sg, volatile Metadata *meta,
                      int rank) {
  constexpr auto FLAG = compute_flag(ngpus);
  __syncthreads();
  __shared__ int num;
  if (threadIdx.x == 0) num = atomicAdd((int *)&meta->counter, 1);
  __syncthreads();

  // Only the last completing block can perform the end synchronization
  // This can ensures when the final busy wait ends, all ranks must have
  // finished reading each other's buffer.
  if (num == gridDim.x - 1) {
    if (threadIdx.x == 32) {
      // reset in a different warp
      meta->counter = 0;
      meta->sg.start.flag = 0;
    } else if (threadIdx.x < ngpus) {
      // simultaneously write to the corresponding byte to all other ranks.
      // Latency = 1 p2p write
      sg.signals[threadIdx.x]->end.data[rank] = 255;
    }
    // if this is the final sync, only one block needs it
    // because kernel exit can serve as sync
    if constexpr (final_sync) {
      if (threadIdx.x == 0) {
        while (meta->sg.end.flag != FLAG)
          ;
      }
    }
  }
  if constexpr (!final_sync) {
    if (threadIdx.x == 0) {
      while (meta->sg.end.flag != FLAG)
        ;
    }
    __syncthreads();
  }
}

template <typename P, int ngpus, typename A>
DINLINE P packed_reduce(const P *ptrs[], int idx) {
  A tmp = upcast(ptrs[0][idx]);
#pragma unroll
  for (int i = 1; i < ngpus; i++) {
    packed_assign_add(tmp, upcast(ptrs[i][idx]));
  }
  return downcast<P>(tmp);
}

template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_1stage(RankData *_dp, RankSignals sg,
                               volatile Metadata *meta, T *__restrict__ result,
                               int rank, int size) {
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  // note: we don't reorder the address so the accumulation order is the same
  // for all ranks, ensuring bitwise identical results
  auto dp = *_dp;
  start_sync<ngpus>(sg, meta, rank);
  // do the actual reduction
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    ((P *)result)[idx] =
        packed_reduce<P, ngpus, A>((const P **)&dp.ptrs[0], idx);
  }
  end_sync<ngpus, true>(sg, meta, rank);
}

template <typename P>
DINLINE P *get_tmp_buf(volatile Signal *sg) {
  return (P *)(((Metadata *)sg) + 1);
}

template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_2stage(RankData *_dp, RankSignals sg,
                               volatile Metadata *meta, T *__restrict__ result,
                               int rank, int size) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = gridDim.x * blockDim.x;
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  int part = size / ngpus;
  int start = rank * part;
  int end = rank == ngpus - 1 ? size : start + part;
  const P *ptrs[ngpus];
  P *tmps[ngpus];
#pragma unroll
  for (int i = 0; i < ngpus; i++) {
    int target = (rank + i) % ngpus;
    ptrs[i] = (const P *)_dp->ptrs[target];
    tmps[i] = get_tmp_buf<P>(sg.signals[target]);
  }
  auto tmp_out = tmps[0];
  start_sync<ngpus>(sg, meta, rank);
  // stage 1: reduce scatter
  for (int idx = start + tid; idx < end; idx += stride) {
    tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
  }
  // Maybe TODO: replace this with per-block release-acquire
  // can save about 1-2us (not a lot though)
  end_sync<ngpus>(sg, meta, rank);

  // stage 2: allgather
  for (int idx = tid; idx < part; idx += stride) {
#pragma unroll
    for (int i = 0; i < ngpus; i++) {
      int dst_idx = ((rank + i) % ngpus) * part + idx;
      ((P *)result)[dst_idx] = tmps[i][idx];
    }
  }
  // process the last larger partition
  int remaining = size - part * ngpus;
  if (tid < remaining) {
    int dst_idx = tid + part * ngpus;
    ((P *)result)[dst_idx] = get_tmp_buf<P>(sg.signals[ngpus - 1])[part + tid];
  }

  // faster than this
  // for (int idx = tid; idx < size; idx += stride) {
  //   int target_rank = idx / part;
  //   if (target_rank == ngpus) target_rank -= 1;
  //   ((P *)result)[idx] = tmps[target_rank][idx - target_rank * part];
  // }
}

template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_half_butterfly(RankData *_dp, RankSignals sg,
                                       volatile Metadata *meta,
                                       T *__restrict__ result, int rank,
                                       int size) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = gridDim.x * blockDim.x;
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  auto tmp_out = get_tmp_buf<P>(sg.signals[rank]);
  constexpr int hg = ngpus / 2;
  // Actually not quite half butterfly.
  // This is an all-to-all within each group containing half of the ranks
  // followed by cross-group add. Equivalent to half butterfly when there
  // are 4 GPUs, a common case for PCIe cards like T4 and A10.
  const P *ptrs[hg];
  {
    int start = rank - rank % hg;
#pragma unroll
    for (int i = 0; i < hg; i++) {
      ptrs[i] = (const P *)_dp->ptrs[i + start];
    }
  }
  start_sync<ngpus>(sg, meta, rank);
  for (int idx = tid; idx < size; idx += stride) {
    tmp_out[idx] = packed_reduce<P, hg, A>(ptrs, idx);
  }
  end_sync<ngpus>(sg, meta, rank);

  auto src = get_tmp_buf<P>(sg.signals[(ngpus - 1) - rank % ngpus]);
  // do the cross group reduction
  for (int idx = tid; idx < size; idx += stride) {
    auto tmp = tmp_out[idx];
    packed_assign_add(tmp, src[idx]);
    ((P *)result)[idx] = tmp;
  }
}

using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t));
static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t));

class CustomAllreduce {
 public:
  int rank_;
  int world_size_;
  bool full_nvlink_;

  // below are device pointers
  RankSignals sg_;
  std::unordered_map<void *, RankData *> buffers_;
  Metadata *meta_;

  // stores the registered device pointers from all ranks
  RankData *d_rank_data_base_, *d_rank_data_end_;
  std::vector<void *> graph_unreg_buffers_;
  // a map from IPC handles to opened IPC pointers
  std::map<IPC_KEY, char *> ipc_handles_;

  /**
   * meta is a pointer to device metadata and temporary buffer for allreduce.
   *
   * There's a total of sizeof(Metadata) of prefix before the actual data,
   * so meta + 1 points to actual temporary buffer.
   *
   * note: this class does not own any device memory. Any required buffers
   * are passed in from the constructor
   */
  CustomAllreduce(Metadata *meta, void *rank_data, size_t rank_data_sz,
                  const cudaIpcMemHandle_t *handles,
                  const std::vector<int64_t> &offsets, int rank,
                  bool full_nvlink = true)
      : rank_(rank),
        world_size_(offsets.size()),
        full_nvlink_(full_nvlink),
        meta_(meta),
        d_rank_data_base_(reinterpret_cast<RankData *>(rank_data)),
        d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
    for (int i = 0; i < world_size_; i++) {
      Metadata *rank_meta;
      if (i != rank_) {
        char *handle = open_ipc_handle(&handles[i]);
        handle += offsets[i];
        rank_meta = (Metadata *)handle;
      } else {
        rank_meta = meta_;
      }
      sg_.signals[i] = &rank_meta->sg;
    }
  }

  char *open_ipc_handle(const void *ipc_handle) {
    auto [it, new_handle] =
        ipc_handles_.insert({*((IPC_KEY *)ipc_handle), nullptr});
    if (new_handle) {
      char *ipc_ptr;
      CUDACHECK(cudaIpcOpenMemHandle((void **)&ipc_ptr,
                                     *((const cudaIpcMemHandle_t *)ipc_handle),
                                     cudaIpcMemLazyEnablePeerAccess));
      it->second = ipc_ptr;
    }
    return it->second;
  }

  std::pair<std::vector<uint8_t>, std::vector<int64_t>>
  get_graph_buffer_ipc_meta() {
    auto num_buffers = graph_unreg_buffers_.size();
    auto handle_sz = sizeof(cudaIpcMemHandle_t);
    std::vector<uint8_t> handles(handle_sz * num_buffers, 0);
    std::vector<int64_t> offsets(num_buffers);
    for (int i = 0; i < num_buffers; i++) {
      auto ptr = graph_unreg_buffers_[i];
      void *base_ptr;
      // note: must share the base address of each allocation, or we get wrong
      // address
      if (cuPointerGetAttribute(&base_ptr,
                                CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
                                (CUdeviceptr)ptr) != CUDA_SUCCESS)
        throw std::runtime_error("failed to get pointer attr");
      CUDACHECK(cudaIpcGetMemHandle(
          (cudaIpcMemHandle_t *)&handles[i * handle_sz], base_ptr));
      offsets[i] = ((char *)ptr) - ((char *)base_ptr);
    }
    return std::make_pair(handles, offsets);
  }

  void check_rank_data_capacity(size_t num = 1) {
    if (d_rank_data_base_ + num > d_rank_data_end_)
      throw std::runtime_error(
          "Rank data buffer is overflowed by " +
          std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
  }

  void register_buffer(const std::vector<std::string> &handles,
                       const std::vector<int64_t> &offsets, void *self) {
    check_rank_data_capacity();
    RankData data;
    for (int i = 0; i < world_size_; i++) {
      if (i != rank_) {
        char *handle = open_ipc_handle(handles[i].data());
        handle += offsets[i];
        data.ptrs[i] = handle;
      } else {
        data.ptrs[i] = self;
      }
    }
    auto d_data = d_rank_data_base_++;
    CUDACHECK(
        cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
    buffers_[self] = d_data;
  }

  // note: when registering graph buffers, we intentionally choose to not
  // deduplicate the addresses. That means if the allocator reuses some
  // addresses, they will be registered again. This is to account for the remote
  // possibility of different allocation patterns between ranks. For example,
  // rank 1 may get the same input address for the second allreduce, but rank 2
  // got a different address. IPC handles have internal reference counting
  // mechanism so overhead should be small.
  void register_graph_buffers(
      const std::vector<std::string> &handles,
      const std::vector<std::vector<int64_t>> &offsets) {
    auto num_buffers = graph_unreg_buffers_.size();
    check_rank_data_capacity(num_buffers);
    std::vector<RankData> rank_data(num_buffers);
    for (int i = 0; i < num_buffers; i++) {
      auto self_ptr = graph_unreg_buffers_[i];
      auto &rd = rank_data[i];
      for (int j = 0; j < world_size_; j++) {
        if (j != rank_) {
          char *handle =
              open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]);
          handle += offsets[j][i];
          rd.ptrs[j] = handle;
        } else {
          rd.ptrs[j] = self_ptr;
        }
      }
    }
    CUDACHECK(cudaMemcpy(d_rank_data_base_, rank_data.data(),
                         sizeof(RankData) * num_buffers,
                         cudaMemcpyHostToDevice));
    d_rank_data_base_ += num_buffers;
    graph_unreg_buffers_.clear();
  }

  /**
   * This is the result after careful grid search. Using 36 blocks give the best
   * or close to the best runtime on the devices I tried: A100, A10, A30, T4,
   * V100. You'll notice that NCCL kernels also only take a small amount of SMs.
   * Not quite sure the underlying reason, but my guess is that too many SMs
   * will cause contention on NVLink bus.
   */
  template <typename T>
  void allreduce(cudaStream_t stream, T *input, T *output, int size,
                 int threads = 512, int block_limit = 36) {
    auto d = packed_t<T>::P::size;
    if (size % d != 0)
      throw std::runtime_error(
          "custom allreduce currently requires input length to be multiple "
          "of " +
          std::to_string(d));

    RankData *ptrs;
    cudaStreamCaptureStatus status;
    CUDACHECK(cudaStreamIsCapturing(stream, &status));
    if (status == cudaStreamCaptureStatusActive) {
      ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
      graph_unreg_buffers_.push_back(input);
    } else {
      auto it = buffers_.find(input);
      if (it == buffers_.end())
        throw std::runtime_error(
            "buffer address " +
            std::to_string(reinterpret_cast<uint64_t>(input)) +
            " is not registered!");
      ptrs = it->second;
    }

    size /= d;
    auto bytes = size * sizeof(typename packed_t<T>::P);
    int blocks = std::min(block_limit, (size + threads - 1) / threads);
#define KL(ngpus, name) \
  name<T, ngpus>        \
      <<<blocks, threads, 0, stream>>>(ptrs, sg_, meta_, output, rank_, size);
#define REDUCE_CASE(ngpus)                            \
  case ngpus: {                                       \
    if (world_size_ == 2) {                           \
      KL(ngpus, cross_device_reduce_1stage);          \
    } else if (full_nvlink_) {                        \
      if ((world_size_ <= 4 && bytes < 512 * 1024) || \
          (world_size_ <= 8 && bytes < 256 * 1024)) { \
        KL(ngpus, cross_device_reduce_1stage);        \
      } else {                                        \
        KL(ngpus, cross_device_reduce_2stage);        \
      }                                               \
    } else {                                          \
      KL(ngpus, cross_device_reduce_half_butterfly);  \
    }                                                 \
    break;                                            \
  }

    switch (world_size_) {
      REDUCE_CASE(2)
      REDUCE_CASE(4)
      REDUCE_CASE(6)
      REDUCE_CASE(8)
      default:
        throw std::runtime_error(
            "custom allreduce only supports num gpus in (2,4,6,8). Actual num "
            "gpus = " +
            std::to_string(world_size_));
    }
#undef REDUCE_CASE
#undef KL
  }

  ~CustomAllreduce() {
    for (auto [_, ptr] : ipc_handles_) {
      CUDACHECK(cudaIpcCloseMemHandle(ptr));
    }
  }
};
/**
 * To inspect PTX/SASS, copy paste this header file to compiler explorer and add
 a template instantiation:
 * template void CustomAllreduce::allreduce<half>(cudaStream_t, half *, half *,
 int, int, int);
 */
}  // namespace vllm
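Note: a quick host-only sanity check of the flag arithmetic behind start_sync/end_sync. Each of the ngpus ranks writes 0xFF into its own byte of the 8-byte flag, so the completed value is exactly the mask compute_flag returns. This check is not part of the diff; it only repeats the same shift.

#include <cstdint>
#include <limits>

// Same arithmetic as vllm::compute_flag above.
constexpr uint64_t compute_flag_check(int ngpus) {
  return std::numeric_limits<uint64_t>::max() >> ((8 - ngpus) * 8);
}

// One 0xFF byte per participating GPU.
static_assert(compute_flag_check(2) == 0x000000000000FFFFull);
static_assert(compute_flag_check(4) == 0x00000000FFFFFFFFull);
static_assert(compute_flag_check(8) == 0xFFFFFFFFFFFFFFFFull);

int main() { return 0; }  // the static_asserts are the test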
csrc/custom_all_reduce_test.cu
ADDED
@@ -0,0 +1,284 @@
/**
 * This is a standalone test for custom allreduce.
 * To compile, make sure you have MPI and NCCL installed in your system.
 * export MPI_HOME=XXX
 * nvcc -O2 -arch=native -std=c++17 custom_all_reduce_test.cu -o
 * custom_all_reduce_test -lnccl -I${MPI_HOME}/include -lmpi
 *
 * Warning: this C++ test is not designed to be very readable and was used
 * during the rapid prototyping process.
 *
 * To run:
 * mpirun -np 8 ./custom_all_reduce_test
 */
#include <cuda.h>
#include <curand_kernel.h>
#include <stdio.h>
#include <stdlib.h>

#include <limits>
#include <vector>

#include "cuda_profiler_api.h"
#include "custom_all_reduce.cuh"
#include "mpi.h"
#include "nccl.h"

#define MPICHECK(cmd)                                                  \
  do {                                                                 \
    int e = cmd;                                                       \
    if (e != MPI_SUCCESS) {                                            \
      printf("Failed: MPI error %s:%d '%d'\n", __FILE__, __LINE__, e); \
      exit(EXIT_FAILURE);                                              \
    }                                                                  \
  } while (0)

#define NCCLCHECK(cmd)                                              \
  do {                                                              \
    ncclResult_t r = cmd;                                           \
    if (r != ncclSuccess) {                                         \
      printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, \
             ncclGetErrorString(r));                                \
      exit(EXIT_FAILURE);                                           \
    }                                                               \
  } while (0)

__global__ void dummy_kernel() {
  for (int i = 0; i < 100; i++) __nanosleep(1000000);  // 100ms
}

template <typename T>
__global__ void set_data(T *data, int size, int myRank) {
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    data[idx] = myRank * 0.11f;
  }
}

template <typename T>
__global__ void convert_data(const T *data1, const T *data2, double *fdata1,
                             double *fdata2, int size) {
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    fdata1[idx] = data1[idx];
    fdata2[idx] = data2[idx];
  }
}

__global__ void init_rand(curandState_t *state, int size, int nRanks) {
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    for (int i = 0; i < nRanks; i++) {
      curand_init(i + 1, idx, 0, &state[idx * nRanks + i]);
    }
  }
}

template <typename T>
__global__ void gen_data(curandState_t *state, T *data, double *ground_truth,
                         int myRank, int nRanks, int size) {
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    double sum = 0.0;
    for (int i = 0; i < nRanks; i++) {
      double val = curand_uniform_double(&state[idx * nRanks + i]) * 4;
      T hval = val;  // downcast first
      sum += static_cast<double>(hval);
      if (i == myRank) data[idx] = hval;
    }
    ground_truth[idx] = sum;
  }
}

template <typename T>
void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
         int data_size) {
  T *result;
  cudaStream_t stream;
  CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
  CUDACHECK(cudaMalloc(&result, data_size * sizeof(T)));
  CUDACHECK(cudaMemset(result, 0, data_size * sizeof(T)));

  cudaIpcMemHandle_t self_data_handle;
  cudaIpcMemHandle_t data_handles[8];
  vllm::Metadata *buffer;
  T *self_data_copy;
  /**
   * Allocate IPC buffer
   *
   * The first section is a temporary buffer for storing intermediate allreduce
   * results, if a particular algorithm requires it. The second section is for
   * the input to the allreduce. The actual API takes the input pointer as an
   * argument (that is, they can and usually should be allocated separately).
   * But since the input pointers and the temporary buffer all require IPC
   * registration, they are allocated and registered together in the test for
   * convenience.
   */
  CUDACHECK(
      cudaMalloc(&buffer, 2 * data_size * sizeof(T) + sizeof(vllm::Metadata)));
  CUDACHECK(cudaMemset(buffer, 0,
                       2 * data_size * sizeof(T) + sizeof(vllm::Metadata)));
  CUDACHECK(cudaMalloc(&self_data_copy, data_size * sizeof(T)));
  CUDACHECK(cudaIpcGetMemHandle(&self_data_handle, buffer));

  MPICHECK(MPI_Allgather(&self_data_handle, sizeof(cudaIpcMemHandle_t),
                         MPI_BYTE, data_handles, sizeof(cudaIpcMemHandle_t),
                         MPI_BYTE, MPI_COMM_WORLD));

  void *rank_data;
  size_t rank_data_sz = 16 * 1024 * 1024;
  CUDACHECK(cudaMalloc(&rank_data, rank_data_sz));
  std::vector<int64_t> offsets(nRanks, 0);
  vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles,
                           offsets, myRank);
  auto *self_data =
      reinterpret_cast<T *>(reinterpret_cast<char *>(buffer) +
                            sizeof(vllm::Metadata) + data_size * sizeof(T));
  // hack buffer registration
  {
    std::vector<std::string> handles;
    handles.reserve(nRanks);
    for (int i = 0; i < nRanks; i++) {
      char *begin = (char *)&data_handles[i];
      char *end = (char *)&data_handles[i + 1];
      handles.emplace_back(begin, end);
    }
    std::vector<int64_t> offsets(
        nRanks, sizeof(vllm::Metadata) + data_size * sizeof(T));
    fa.register_buffer(handles, offsets, self_data);
  }

  double *ground_truth;
  CUDACHECK(cudaMallocHost(&ground_truth, data_size * sizeof(double)));
  curandState_t *states;
  CUDACHECK(cudaMalloc(&states, sizeof(curandState_t) * nRanks * data_size));
  init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks);
  gen_data<T><<<108, 1024, 0, stream>>>(states, self_data, ground_truth, myRank,
                                        nRanks, data_size);
  CUDACHECK(cudaMemcpyAsync(self_data_copy, self_data, data_size * sizeof(T),
                            cudaMemcpyDeviceToDevice, stream));
  cudaEvent_t start, stop;
  CUDACHECK(cudaEventCreate(&start));
  CUDACHECK(cudaEventCreate(&stop));

  ncclDataType_t ncclDtype;
  if (std::is_same<T, half>::value) {
    ncclDtype = ncclFloat16;
  } else if (std::is_same<T, nv_bfloat16>::value) {
    ncclDtype = ncclBfloat16;
  } else {
    ncclDtype = ncclFloat;
  }

  dummy_kernel<<<1, 1, 0, stream>>>();
  constexpr int warmup_iters = 5;
  constexpr int num_iters = 25;
  // warmup
  for (int i = 0; i < warmup_iters; i++) {
    NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum, comm,
                            stream));
  }
  CUDACHECK(cudaEventRecord(start, stream));
  for (int i = 0; i < num_iters; i++) {
    NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum, comm,
                            stream));
  }
  CUDACHECK(cudaEventRecord(stop, stream));
  CUDACHECK(cudaStreamSynchronize(stream));
  float allreduce_ms = 0;
  cudaEventElapsedTime(&allreduce_ms, start, stop);

  // if (myRank == 1) dummy_kernel<<<1, 1, 0, stream>>>();
  // set_data<T><<<16, 1024, 0, stream>>>(self_data, data_size, myRank);

  dummy_kernel<<<1, 1, 0, stream>>>();
  // warm up
  for (int i = 0; i < warmup_iters; i++) {
    fa.allreduce<T>(stream, self_data, result, data_size, threads, block_limit);
  }
  CUDACHECK(cudaEventRecord(start, stream));
  for (int i = 0; i < num_iters; i++) {
    fa.allreduce<T>(stream, self_data, result, data_size, threads, block_limit);
  }
  CUDACHECK(cudaEventRecord(stop, stream));
  CUDACHECK(cudaStreamSynchronize(stream));

  float duration_ms = 0;
  cudaEventElapsedTime(&duration_ms, start, stop);
  if (myRank == 0)
    printf(
        "Rank %d done, nGPUs:%d, sz (kb): %d, %d, %d, my time:%.2fus, nccl "
        "time:%.2fus\n",
        myRank, nRanks, data_size * sizeof(T) / 1024, threads, block_limit,
        duration_ms * 1e3 / num_iters, allreduce_ms * 1e3 / num_iters);

  // And wait for all the queued up work to complete
  CUDACHECK(cudaStreamSynchronize(stream));

  NCCLCHECK(ncclAllReduce(self_data_copy, self_data, data_size, ncclDtype,
                          ncclSum, comm, stream));

  double *nccl_result, *my_result;
  CUDACHECK(cudaMallocHost(&nccl_result, data_size * sizeof(double)));
  CUDACHECK(cudaMallocHost(&my_result, data_size * sizeof(double)));

  convert_data<T><<<108, 1024, 0, stream>>>(self_data, result, nccl_result,
                                            my_result, data_size);
  CUDACHECK(cudaStreamSynchronize(stream));

  for (unsigned long j = 0; j < data_size; j++) {
    auto diff = abs(nccl_result[j] - my_result[j]);
    if (diff >= 1e-2) {
      printf("Rank %d: Verification mismatch at %lld: %f != (my) %f, gt=%f\n",
             myRank, j, nccl_result[j], my_result[j], ground_truth[j]);
      break;
    }
  }

  long double nccl_diffs = 0.0;
  long double my_diffs = 0.0;
  for (int j = 0; j < data_size; j++) {
    nccl_diffs += abs(nccl_result[j] - ground_truth[j]);
    my_diffs += abs(my_result[j] - ground_truth[j]);
  }
  if (myRank == 0)
    std::cout << "average abs diffs: nccl: " << nccl_diffs / data_size
              << " me: " << my_diffs / data_size << std::endl;

  CUDACHECK(cudaFree(result));
  CUDACHECK(cudaFree(self_data_copy));
  CUDACHECK(cudaFree(rank_data));
  CUDACHECK(cudaFree(buffer));
  CUDACHECK(cudaFree(states));
  CUDACHECK(cudaFreeHost(ground_truth));
  CUDACHECK(cudaFreeHost(nccl_result));
  CUDACHECK(cudaFreeHost(my_result));
  CUDACHECK(cudaStreamDestroy(stream));
}

int main(int argc, char **argv) {
  int nRanks, myRank;
  MPICHECK(MPI_Init(&argc, &argv));
  MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank));
  MPICHECK(MPI_Comm_size(MPI_COMM_WORLD, &nRanks));
  CUDACHECK(cudaSetDevice(myRank));
  ncclUniqueId id;
  ncclComm_t comm;
  if (myRank == 0) ncclGetUniqueId(&id);
  MPICHECK(MPI_Bcast(static_cast<void *>(&id), sizeof(id), MPI_BYTE, 0,
                     MPI_COMM_WORLD));
  NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank));

  cudaProfilerStart();
  // for (int threads : {256, 512}) {
  //   for (int block_limit = 16; block_limit < 112; block_limit += 4) {
  //     run<half>(myRank, nRanks, comm, threads, block_limit, 4096 * 1024);
  //   }
  // }
  for (int sz = 512; sz <= (32 << 20); sz *= 2) {
    run<half>(myRank, nRanks, comm, 512, 36, sz + 8 * 50);
  }

  cudaProfilerStop();
  return EXIT_SUCCESS;
}
csrc/dispatch_utils.h
ADDED
@@ -0,0 +1,37 @@
/*
 * Adapted from
 * https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
 */
#pragma once

#include <torch/extension.h>

#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)         \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                 \
      TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...)   \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)    \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)     \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                          \
      TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...)         \
  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)   \
  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)

#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                 \
      TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
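Note: these macros simply wrap ATen's AT_DISPATCH_SWITCH/AT_DISPATCH_CASE. A minimal sketch of how a launcher would use one of them; the elementwise `scale_kernel` below is a made-up example (and assumes <ATen/cuda/CUDAContext.h> and this header are included), not a kernel added by this diff.

// Hypothetical example: dispatch a float/half/bfloat16 elementwise kernel.
template <typename scalar_t>
__global__ void scale_kernel(scalar_t* __restrict__ x, float factor, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) x[i] = (scalar_t)((float)x[i] * factor);
}

void scale(torch::Tensor& x, float factor) {
  int n = x.numel();
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
    x.scalar_type(), "scale_kernel", [&] {
      // The macro binds `scalar_t` to float, half or bfloat16 per the tensor dtype.
      scale_kernel<scalar_t><<<(n + 255) / 256, 256, 0, stream>>>(
        x.data_ptr<scalar_t>(), factor, n);
    });
}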
csrc/layernorm_kernels.cu
ADDED
@@ -0,0 +1,120 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "dispatch_utils.h"
#include "reduction_utils.cuh"

namespace vllm {

// TODO(woosuk): Further optimize this kernel.
template<typename scalar_t>
__global__ void rms_norm_kernel(
  scalar_t* __restrict__ out,             // [..., hidden_size]
  const scalar_t* __restrict__ input,     // [..., hidden_size]
  const scalar_t* __restrict__ weight,    // [hidden_size]
  const float epsilon,
  const int num_tokens,
  const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    const float x = (float) input[blockIdx.x * hidden_size + idx];
    variance += x * x;
  }
  variance = blockReduceSum<float>(variance);
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float) input[blockIdx.x * hidden_size + idx];
    out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
  }
}

// TODO: Further optimize this kernel.
template<typename scalar_t>
__global__ void fused_add_rms_norm_kernel(
  scalar_t* __restrict__ input,           // [..., hidden_size]
  scalar_t* __restrict__ residual,        // [..., hidden_size]
  const scalar_t* __restrict__ weight,    // [hidden_size]
  const float epsilon,
  const int num_tokens,
  const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float) input[blockIdx.x * hidden_size + idx];
    x += (float) residual[blockIdx.x * hidden_size + idx];
    variance += x * x;
    residual[blockIdx.x * hidden_size + idx] = (scalar_t) x;
  }
  variance = blockReduceSum<float>(variance);
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float) residual[blockIdx.x * hidden_size + idx];
    input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
  }
}

} // namespace vllm

void rms_norm(
  torch::Tensor& out,      // [..., hidden_size]
  torch::Tensor& input,    // [..., hidden_size]
  torch::Tensor& weight,   // [hidden_size]
  float epsilon) {
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;

  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
    input.scalar_type(),
    "rms_norm_kernel",
    [&] {
      vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out.data_ptr<scalar_t>(),
        input.data_ptr<scalar_t>(),
        weight.data_ptr<scalar_t>(),
        epsilon,
        num_tokens,
        hidden_size);
    });
}

void fused_add_rms_norm(
  torch::Tensor& input,    // [..., hidden_size]
  torch::Tensor& residual, // [..., hidden_size]
  torch::Tensor& weight,   // [hidden_size]
  float epsilon) {
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;

  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
    input.scalar_type(),
    "fused_add_rms_norm_kernel",
    [&] {
      vllm::fused_add_rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
        input.data_ptr<scalar_t>(),
        residual.data_ptr<scalar_t>(),
        weight.data_ptr<scalar_t>(),
        epsilon,
        num_tokens,
        hidden_size);
    });
}
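Note: per token, both kernels compute y_i = x_i * w_i / sqrt(mean_j(x_j^2) + eps), with the fused variant first adding the residual into x and writing that sum back to `residual`. A small CPU reference for a single token, useful for checking results; floats only, and not part of the diff.

#include <cmath>
#include <vector>

std::vector<float> rms_norm_reference(const std::vector<float>& x,
                                      const std::vector<float>& weight,
                                      float epsilon) {
  float variance = 0.0f;
  for (float v : x) variance += v * v;              // sum of squares
  const float inv_rms = 1.0f / std::sqrt(variance / x.size() + epsilon);
  std::vector<float> out(x.size());
  for (size_t i = 0; i < x.size(); ++i)
    out[i] = x[i] * inv_rms * weight[i];            // normalize, then re-weight
  return out;
}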
csrc/moe_align_block_size_kernels.cu
ADDED
@@ -0,0 +1,108 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include <ATen/ATen.h>
#include <THC/THCAtomics.cuh>

#include "cuda_compat.h"
#include "dispatch_utils.h"

const static size_t NUM_MAX_EXPERTS = 64;
#define CEILDIV(x,y) (((x) + (y) - 1) / (y))

namespace vllm {
template <typename scalar_t>
__global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
                                            int32_t *sorted_token_ids,
                                            int32_t *expert_ids,
                                            int32_t *total_tokens_post_pad,
                                            int32_t num_experts,
                                            int32_t block_size,
                                            size_t numel) {
  const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
  const size_t start_idx = threadIdx.x * tokens_per_thread;
  __shared__ int32_t tokens_cnts[NUM_MAX_EXPERTS + 1][NUM_MAX_EXPERTS];
  __shared__ int32_t cumsum[NUM_MAX_EXPERTS + 1];
  for (int i = 0; i < num_experts; ++i) {
    tokens_cnts[threadIdx.x + 1][i] = 0;
  }

  /**
   * In the first step we compute token_cnts[thread_index + 1][expert_index],
   * which counts how many tokens in the token shard of thread_index are assigned
   * to expert expert_index.
   */
  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    ++tokens_cnts[threadIdx.x + 1][topk_ids[i]];
  }

  __syncthreads();

  // For each expert we accumulate the token counts from the different threads.
  tokens_cnts[0][threadIdx.x] = 0;
  for (int i = 1; i <= blockDim.x; ++i) {
    tokens_cnts[i][threadIdx.x] += tokens_cnts[i-1][threadIdx.x];
  }

  __syncthreads();

  // We accumulate the token counts of all experts in thread 0.
  if (threadIdx.x == 0) {
    cumsum[0] = 0;
    for (int i = 1; i <= num_experts; ++i) {
      cumsum[i] = cumsum[i-1] + CEILDIV(tokens_cnts[blockDim.x][i - 1], block_size) * block_size;
    }
    *total_tokens_post_pad = cumsum[num_experts];
  }

  __syncthreads();

  /**
   * For each expert, each thread processes the tokens of the corresponding blocks
   * and stores the corresponding expert_id for each block.
   */
  for (int i = cumsum[threadIdx.x];i < cumsum[threadIdx.x + 1];i += block_size) {
    expert_ids[i / block_size] = threadIdx.x;
  }

  /**
   * Each thread processes a token shard, calculating the index of each token after
   * sorting by expert number. Given the example topk_ids = [0,1,2,1,2,3,0,3,4] and
   * block_size = 4, then the output would be [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *],
   * where * represents a padding value(preset in python).
   */
  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    int32_t expert_id = topk_ids[i];
    /** The cumsum[expert_id] stores the starting index of the tokens that the
     * expert with expert_id needs to process, and tokens_cnts[threadIdx.x][expert_id]
     * stores the indices of the tokens processed by the expert with expert_id within
     * the current thread's token shard.
     */
    int32_t rank_post_pad = tokens_cnts[threadIdx.x][expert_id] + cumsum[expert_id];
    sorted_token_ids[rank_post_pad] = i;
    ++tokens_cnts[threadIdx.x][expert_id];
  }
}
}

void moe_align_block_size(
  torch::Tensor topk_ids,
  int num_experts,
  int block_size,
  torch::Tensor sorted_token_ids,
  torch::Tensor experts_ids,
  torch::Tensor num_tokens_post_pad) {
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  assert(num_experts <= NUM_MAX_EXPERTS);
  VLLM_DISPATCH_INTEGRAL_TYPES(
    topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
      vllm::moe_align_block_size_kernel<scalar_t><<<1, num_experts, 0, stream>>>(
        topk_ids.data_ptr<scalar_t>(),
        sorted_token_ids.data_ptr<int32_t>(),
        experts_ids.data_ptr<int32_t>(),
        num_tokens_post_pad.data_ptr<int32_t>(),
        num_experts,
        block_size,
        topk_ids.numel());
    });
}
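Note: a single-threaded host reference of the padding/sorting scheme, using the example from the kernel comment above; padding slots are -1 here (the real path presets them from Python), and the helper names are illustrative rather than part of the diff.

#include <cstdint>
#include <vector>

struct MoeAlignResult {
  std::vector<int32_t> sorted_token_ids;
  std::vector<int32_t> expert_ids;
  int32_t total_tokens_post_pad;
};

MoeAlignResult moe_align_reference(const std::vector<int32_t>& topk_ids,
                                   int num_experts, int block_size) {
  auto ceildiv = [](int x, int y) { return (x + y - 1) / y; };
  std::vector<int32_t> counts(num_experts, 0);
  for (int32_t e : topk_ids) counts[e]++;

  // Per-expert start offsets, each expert padded up to a multiple of block_size.
  std::vector<int32_t> cumsum(num_experts + 1, 0);
  for (int e = 0; e < num_experts; ++e)
    cumsum[e + 1] = cumsum[e] + ceildiv(counts[e], block_size) * block_size;

  MoeAlignResult r;
  r.total_tokens_post_pad = cumsum[num_experts];
  r.sorted_token_ids.assign(r.total_tokens_post_pad, -1);
  r.expert_ids.assign(r.total_tokens_post_pad / block_size, 0);
  for (int e = 0; e < num_experts; ++e)
    for (int b = cumsum[e]; b < cumsum[e + 1]; b += block_size)
      r.expert_ids[b / block_size] = e;

  // Stable placement of each token inside its expert's padded region.
  std::vector<int32_t> fill(num_experts, 0);
  for (size_t i = 0; i < topk_ids.size(); ++i) {
    int32_t e = topk_ids[i];
    r.sorted_token_ids[cumsum[e] + fill[e]++] = (int32_t)i;
  }
  // topk_ids = {0,1,2,1,2,3,0,3,4}, block_size = 4, num_experts = 5 gives
  // [0,6,-1,-1, 1,3,-1,-1, 2,4,-1,-1, 5,7,-1,-1, 8,-1,-1,-1].
  return r;
}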
csrc/ops.h
ADDED
@@ -0,0 +1,130 @@
#pragma once

#include <torch/extension.h>

void paged_attention_v1(
  torch::Tensor& out,
  torch::Tensor& query,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  int num_kv_heads,
  float scale,
  torch::Tensor& block_tables,
  torch::Tensor& context_lens,
  int block_size,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes,
  const std::string& kv_cache_dtype);

void paged_attention_v2(
  torch::Tensor& out,
  torch::Tensor& exp_sums,
  torch::Tensor& max_logits,
  torch::Tensor& tmp_out,
  torch::Tensor& query,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  int num_kv_heads,
  float scale,
  torch::Tensor& block_tables,
  torch::Tensor& context_lens,
  int block_size,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes,
  const std::string& kv_cache_dtype);

void rms_norm(
  torch::Tensor& out,
  torch::Tensor& input,
  torch::Tensor& weight,
  float epsilon);

void fused_add_rms_norm(
  torch::Tensor& input,
  torch::Tensor& residual,
  torch::Tensor& weight,
  float epsilon);

void rotary_embedding(
  torch::Tensor& positions,
  torch::Tensor& query,
  torch::Tensor& key,
  int head_size,
  torch::Tensor& cos_sin_cache,
  bool is_neox);

void silu_and_mul(
  torch::Tensor& out,
  torch::Tensor& input);

void gelu_new(
  torch::Tensor& out,
  torch::Tensor& input);

void gelu_fast(
  torch::Tensor& out,
  torch::Tensor& input);

#ifndef USE_ROCM
torch::Tensor awq_gemm(
  torch::Tensor _in_feats,
  torch::Tensor _kernel,
  torch::Tensor _scaling_factors,
  torch::Tensor _zeros,
  int split_k_iters);

torch::Tensor awq_dequantize(
  torch::Tensor _kernel,
  torch::Tensor _scaling_factors,
  torch::Tensor _zeros,
  int split_k_iters,
  int thx,
  int thy);
#endif

void squeezellm_gemm(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor lookup_table);

torch::Tensor gptq_gemm(
  torch::Tensor a,
  torch::Tensor b_q_weight,
  torch::Tensor b_gptq_qzeros,
  torch::Tensor b_gptq_scales,
  torch::Tensor b_g_idx,
  bool use_exllama);

void gptq_shuffle(
  torch::Tensor q_weight,
  torch::Tensor q_perm);

void moe_align_block_size(
  torch::Tensor topk_ids,
  int num_experts,
  int block_size,
  torch::Tensor sorted_token_ids,
  torch::Tensor experts_ids,
  torch::Tensor num_tokens_post_pad);

#ifndef USE_ROCM
using fptr_t = uint64_t;
fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
                      const std::vector<std::string> &handles,
                      const std::vector<int64_t> &offsets, int rank,
                      bool full_nvlink);
bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
                      bool full_nvlink);
void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out);
void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer,
                      torch::Tensor &out);
void dispose(fptr_t _fa);
int meta_size();
void register_buffer(fptr_t _fa, torch::Tensor &t,
                     const std::vector<std::string> &handles,
                     const std::vector<int64_t> &offsets);
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
                            const std::vector<std::vector<int64_t>> &offsets);
#endif
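The declarations in csrc/ops.h are implemented by the CUDA sources elsewhere in this diff and exposed to Python through a Torch extension module. Below is a minimal, hedged sketch of how such declarations are typically bound with pybind11; the file layout and docstrings are assumptions for illustration, not necessarily the binding code used in this commit.

// Hypothetical pybind11 binding sketch: each C++ entry point from ops.h is
// registered on the extension module so Python code can call it on CUDA tensors.
#include <torch/extension.h>
#include "ops.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("rms_norm", &rms_norm, "Apply RMS normalization (CUDA)");
  m.def("silu_and_mul", &silu_and_mul, "Fused SiLU-and-multiply activation (CUDA)");
  m.def("rotary_embedding", &rotary_embedding, "Apply rotary positional embedding (CUDA)");
  m.def("paged_attention_v1", &paged_attention_v1, "PagedAttention kernel, v1 (CUDA)");
}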
csrc/pos_encoding_kernels.cu
ADDED
@@ -0,0 +1,130 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "cuda_compat.h"
#include "dispatch_utils.h"

namespace vllm {

template<typename scalar_t, bool IS_NEOX>
inline __device__ void apply_rotary_embedding(
  scalar_t* __restrict__ arr,
  const scalar_t* __restrict__ cos_ptr,
  const scalar_t* __restrict__ sin_ptr,
  int rot_offset,
  int embed_dim)
{
  int x_index, y_index;
  scalar_t cos, sin;
  if (IS_NEOX) {
    // GPT-NeoX style rotary embedding.
    x_index = rot_offset;
    y_index = embed_dim + rot_offset;
    cos = VLLM_LDG(cos_ptr + x_index);
    sin = VLLM_LDG(sin_ptr + x_index);
  } else {
    // GPT-J style rotary embedding.
    x_index = 2 * rot_offset;
    y_index = 2 * rot_offset + 1;
    cos = VLLM_LDG(cos_ptr + x_index / 2);
    sin = VLLM_LDG(sin_ptr + x_index / 2);
  }

  const scalar_t x = arr[x_index];
  const scalar_t y = arr[y_index];
  arr[x_index] = x * cos - y * sin;
  arr[y_index] = y * cos + x * sin;
}

template<typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_kernel(
  const int64_t* __restrict__ positions,       // [batch_size, seq_len] or [num_tokens]
  scalar_t* __restrict__ query,                // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
  scalar_t* __restrict__ key,                  // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
  const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim // 2]
  const int rot_dim,
  const int64_t query_stride,
  const int64_t key_stride,
  const int num_heads,
  const int num_kv_heads,
  const int head_size) {
  // Each thread block is responsible for one token.
  const int token_idx = blockIdx.x;
  int64_t pos = positions[token_idx];
  const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;

  const int embed_dim = rot_dim / 2;
  const scalar_t* cos_ptr = cache_ptr;
  const scalar_t* sin_ptr = cache_ptr + embed_dim;

  const int nq = num_heads * embed_dim;
  for (int i = threadIdx.x; i < nq; i += blockDim.x) {
    const int head_idx = i / embed_dim;
    const int64_t token_head = token_idx * query_stride + head_idx * head_size;
    const int rot_offset = i % embed_dim;
    apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
                                              sin_ptr, rot_offset, embed_dim);
  }

  const int nk = num_kv_heads * embed_dim;
  for (int i = threadIdx.x; i < nk; i += blockDim.x) {
    const int head_idx = i / embed_dim;
    const int64_t token_head = token_idx * key_stride + head_idx * head_size;
    const int rot_offset = i % embed_dim;
    apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
                                              sin_ptr, rot_offset, embed_dim);
  }
}

} // namespace vllm

void rotary_embedding(
  torch::Tensor& positions,      // [batch_size, seq_len] or [num_tokens]
  torch::Tensor& query,          // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
  torch::Tensor& key,            // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
  int head_size,
  torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
  bool is_neox) {
  int64_t num_tokens = query.numel() / query.size(-1);
  int rot_dim = cos_sin_cache.size(1);
  int num_heads = query.size(-1) / head_size;
  int num_kv_heads = key.size(-1) / head_size;
  int64_t query_stride = query.stride(-2);
  int64_t key_stride = key.stride(-2);

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * rot_dim / 2, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
    query.scalar_type(),
    "rotary_embedding",
    [&] {
      if (is_neox) {
        vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
          positions.data_ptr<int64_t>(),
          query.data_ptr<scalar_t>(),
          key.data_ptr<scalar_t>(),
          cos_sin_cache.data_ptr<scalar_t>(),
          rot_dim,
          query_stride,
          key_stride,
          num_heads,
          num_kv_heads,
          head_size);
      } else {
        vllm::rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
          positions.data_ptr<int64_t>(),
          query.data_ptr<scalar_t>(),
          key.data_ptr<scalar_t>(),
          cos_sin_cache.data_ptr<scalar_t>(),
          rot_dim,
          query_stride,
          key_stride,
          num_heads,
          num_kv_heads,
          head_size);
      }
    });
}
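As a rough illustration of the shapes the host entry point above expects, the following libtorch snippet builds flat [num_tokens, ...] tensors and invokes rotary_embedding in place. The concrete sizes and the helper function name are example values chosen for this sketch, not part of the commit.

// Hypothetical caller sketch for the rotary_embedding entry point above.
// Shapes follow the comments in the signature; sizes here are arbitrary examples.
#include <torch/extension.h>
#include "ops.h"

void run_rotary_example() {
  const int64_t num_tokens = 4, num_heads = 8, num_kv_heads = 8;
  const int head_size = 64, rot_dim = 64, max_position = 2048;
  auto opts = torch::TensorOptions().dtype(torch::kFloat16).device(torch::kCUDA);

  torch::Tensor positions = torch::arange(num_tokens, opts.dtype(torch::kInt64));
  torch::Tensor query = torch::randn({num_tokens, num_heads * head_size}, opts);
  torch::Tensor key = torch::randn({num_tokens, num_kv_heads * head_size}, opts);
  torch::Tensor cos_sin_cache = torch::randn({max_position, rot_dim}, opts);

  // Updates query and key in place; is_neox = true selects the GPT-NeoX layout.
  rotary_embedding(positions, query, key, head_size, cos_sin_cache, /*is_neox=*/true);
}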
csrc/punica/LICENSE
ADDED
@@ -0,0 +1,217 @@
Contains code from https://github.com/punica-ai/punica

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "{}"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright {yyyy} {name of copyright owner}

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

------------------------------------------------------------------------------------

This product bundles various third-party components under other open source licenses.
This section summarizes those components and their licenses. See licenses/
for text of these licenses.


Apache-2.0
* third_party/nvbench (with LLVM exception)
* third_party/flashinfer

BSD-3-Clause:
* third_party/cutlass
csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu
ADDED
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)
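The single FOR_BGMV_WIDE_NARROW(...) line above appears to use an X-macro from bgmv_config.h to stamp out one template instantiation per supported shape and dtype combination. Below is a self-contained, hypothetical illustration of that general pattern; the macro names, stub function, and sizes are invented for clarity and are not the actual definitions in bgmv_config.h.

// Hypothetical illustration of the X-macro explicit-instantiation pattern:
// one macro enumerates the cases, another expands each case into an explicit
// template instantiation. All names and sizes below are made up.
#include <cstdio>

template <typename T, int feat_in, int feat_out>
void demo_bgmv_stub() {
  std::printf("instantiated: %d -> %d\n", feat_in, feat_out);
}

#define FOR_DEMO_SHAPES(f, T) \
  f(T, 16, 4096)              \
  f(T, 4096, 16)

#define INST_DEMO(T, feat_in, feat_out) \
  template void demo_bgmv_stub<T, feat_in, feat_out>();

FOR_DEMO_SHAPES(INST_DEMO, float)  // expands to two explicit instantiations

int main() {
  demo_bgmv_stub<float, 16, 4096>();
  return 0;
}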