diff --git a/.buildkite/run-benchmarks.sh b/.buildkite/run-benchmarks.sh
new file mode 100644
index 0000000000000000000000000000000000000000..dde28cb55605efaacd7705fee5e8e08f34588342
--- /dev/null
+++ b/.buildkite/run-benchmarks.sh
@@ -0,0 +1,63 @@
+# This script is run by Buildkite to run the benchmarks and upload the results as a Buildkite annotation.
+
+set -ex
+set -o pipefail
+
+# cd into the repository root (the parent of this script's directory)
+cd "$(dirname "${BASH_SOURCE[0]}")/.."
+
+# Install wget and curl if they are not already available.
+(which wget && which curl) || (apt-get update && apt-get install -y wget curl)
+
+# run benchmarks and upload the result to buildkite
+# Capture each benchmark's exit code without aborting under `set -e`, so that
+# the remaining benchmarks still run and the results are still uploaded.
+bench_latency_exit_code=0
+python3 benchmarks/benchmark_latency.py 2>&1 | tee benchmark_latency.txt || bench_latency_exit_code=$?
+
+bench_throughput_exit_code=0
+python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 2>&1 | tee benchmark_throughput.txt || bench_throughput_exit_code=$?
+
+python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-chat-hf &
+server_pid=$!
+wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
+
+# wait for server to start, timeout after 600 seconds
+timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
+bench_serving_exit_code=0
+python3 benchmarks/benchmark_serving.py \
+    --dataset ./ShareGPT_V3_unfiltered_cleaned_split.json \
+    --model meta-llama/Llama-2-7b-chat-hf \
+    --num-prompts 20 \
+    --endpoint /v1/completions \
+    --tokenizer meta-llama/Llama-2-7b-chat-hf 2>&1 | tee benchmark_serving.txt || bench_serving_exit_code=$?
+kill $server_pid
+
+# write the results into a markdown file
+echo "### Latency Benchmarks" >> benchmark_results.md
+sed -n '1p' benchmark_latency.txt >> benchmark_results.md # first line
+echo "" >> benchmark_results.md
+sed -n '$p' benchmark_latency.txt >> benchmark_results.md # last line
+
+echo "### Throughput Benchmarks" >> benchmark_results.md
+sed -n '1p' benchmark_throughput.txt >> benchmark_results.md # first line
+echo "" >> benchmark_results.md
+sed -n '$p' benchmark_throughput.txt >> benchmark_results.md # last line
+
+echo "### Serving Benchmarks" >> benchmark_results.md
+sed -n '1p' benchmark_serving.txt >> benchmark_results.md # first line
+echo "" >> benchmark_results.md
+tail -n 5 benchmark_serving.txt >> benchmark_results.md # last 5 lines
+
+# upload the results to buildkite
+/workspace/buildkite-agent annotate --style "info" --context "benchmark-results" < benchmark_results.md
+
+# exit with the exit code of the benchmarks
+if [ $bench_latency_exit_code -ne 0 ]; then
+ exit $bench_latency_exit_code
+fi
+
+if [ $bench_throughput_exit_code -ne 0 ]; then
+ exit $bench_throughput_exit_code
+fi
+
+if [ $bench_serving_exit_code -ne 0 ]; then
+ exit $bench_serving_exit_code
+fi
diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..65ac2f74fb8dc9de7ae1b9b6de6c2c43207ff285
--- /dev/null
+++ b/.buildkite/test-pipeline.yaml
@@ -0,0 +1,51 @@
+# In this file, you can add more tests to run either by adding a new step or
+# adding a new command to an existing step. See the example step below for the
+# available options. This file is fed into the Jinja template in
+# `test-template.j2` to generate the final pipeline yaml file.
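+#
+# For illustration only, a hypothetical new step could look like this (the label
+# and test path are placeholders, not real tests):
+#
+# - label: My New Test
+#   command: pytest -v -s my_new_test
+#   working_dir: "/vllm-workspace/tests"  # optional, defaults to this path
+#   num_gpus: 2                           # optional, only 1 or 2 are supported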
+
+steps:
+- label: Regression Test
+ command: pytest -v -s test_regression.py
+ working_dir: "/vllm-workspace/tests" # optional
+
+- label: AsyncEngine Test
+ command: pytest -v -s async_engine
+
+- label: Distributed Test
+ command: pytest -v -s test_comm_ops.py
+ working_dir: "/vllm-workspace/tests/distributed"
+ num_gpus: 2 # only support 1 or 2 for now.
+
+- label: Engine Test
+ command: pytest -v -s engine
+
+- label: Entrypoints Test
+ command: pytest -v -s entrypoints
+
+- label: Kernels Test
+ command: pytest -v -s kernels
+ soft_fail: true
+
+- label: Models Test
+ commands:
+ - pytest -v -s models --forked
+ soft_fail: true
+
+- label: Prefix Caching Test
+ commands:
+ - pytest -v -s prefix_caching
+
+- label: Samplers Test
+ command: pytest -v -s samplers --forked
+
+- label: Worker Test
+ command: pytest -v -s worker
+
+- label: LoRA Test
+ command: pytest -v -s lora
+
+- label: Benchmarks
+ working_dir: "/vllm-workspace/.buildkite"
+ commands:
+ - pip install aiohttp
+ - bash run-benchmarks.sh
diff --git a/.buildkite/test-template.j2 b/.buildkite/test-template.j2
new file mode 100644
index 0000000000000000000000000000000000000000..7c709b6097fd4c5b07d1a30fc22787cc625cec45
--- /dev/null
+++ b/.buildkite/test-template.j2
@@ -0,0 +1,54 @@
+{% set docker_image = "us-central1-docker.pkg.dev/vllm-405802/vllm-ci-test-repo/vllm-test:$BUILDKITE_COMMIT" %}
+{% set default_num_gpu = 1 %}
+{% set default_working_dir = "/vllm-workspace/tests" %}
+
+steps:
+ - label: ":docker: build image"
+ commands:
+ - "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ."
+ - "docker push {{ docker_image }}"
+ env:
+ DOCKER_BUILDKIT: "1"
+ retry:
+ automatic:
+ - exit_status: -1 # Agent was lost
+ limit: 5
+ - wait
+
+ {% for step in steps %}
+ - label: "{{ step.label }}"
+ agents:
+ queue: kubernetes
+ soft_fail: {{ step.soft_fail or false }}
+ retry:
+ automatic:
+ - exit_status: -1 # Agent was lost
+ limit: 5
+ plugins:
+ - kubernetes:
+ podSpec:
+ volumes:
+ - name: dshm
+ emptyDir:
+ medium: Memory
+ containers:
+ - image: "{{ docker_image }}"
+ command: ["bash"]
+ args:
+ - "-c"
+ - "'cd {{ (step.working_dir or default_working_dir) | safe }} && {{ step.command or (step.commands | join(' && ')) | safe }}'"
+ resources:
+ requests:
+ nvidia.com/gpu: "{{ step.num_gpus or default_num_gpu }}"
+ limits:
+ nvidia.com/gpu: "{{ step.num_gpus or default_num_gpu }}"
+ env:
+ - name: HF_TOKEN
+ valueFrom:
+ secretKeyRef:
+ name: hf-token-secret
+ key: token
+ volumeMounts:
+ - mountPath: /dev/shm
+ name: dshm
+ {% endfor %}
diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 0000000000000000000000000000000000000000..5cfe0dcb065dc88226d754e98abbaa52125b1ccc
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1 @@
+vllm/*.so
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
new file mode 100644
index 0000000000000000000000000000000000000000..5211dc180798eccd97e56964a340e9028bee8104
--- /dev/null
+++ b/.github/workflows/publish.yml
@@ -0,0 +1,102 @@
+# This workflow will upload a Python Package to Release asset
+# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions
+
+name: Create Release
+
+on:
+ push:
+ tags:
+ - v*
+
+# Needed to create release and upload assets
+permissions:
+ contents: write
+
+jobs:
+ release:
+ # Retrieve tag and create release
+ name: Create Release
+ runs-on: ubuntu-latest
+ outputs:
+ upload_url: ${{ steps.create_release.outputs.upload_url }}
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v3
+
+ - name: Extract branch info
+ shell: bash
+ run: |
+ echo "release_tag=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV
+
+ - name: Create Release
+ id: create_release
+ uses: "actions/github-script@v6"
+ env:
+ RELEASE_TAG: ${{ env.release_tag }}
+ with:
+ github-token: "${{ secrets.GITHUB_TOKEN }}"
+ script: |
+ const script = require('.github/workflows/scripts/create_release.js')
+ await script(github, context, core)
+
+ wheel:
+ name: Build Wheel
+ runs-on: ${{ matrix.os }}
+ needs: release
+
+ strategy:
+ fail-fast: false
+ matrix:
+ os: ['ubuntu-20.04']
+ python-version: ['3.8', '3.9', '3.10', '3.11']
+ pytorch-version: ['2.1.2'] # Must be the most recent version that meets requirements.txt.
+ cuda-version: ['11.8', '12.1']
+
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v3
+
+ - name: Set up Linux Env
+ if: ${{ runner.os == 'Linux' }}
+ run: |
+ bash -x .github/workflows/scripts/env.sh
+
+ - name: Set up Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python-version }}
+
+ - name: Install CUDA ${{ matrix.cuda-version }}
+ run: |
+ bash -x .github/workflows/scripts/cuda-install.sh ${{ matrix.cuda-version }} ${{ matrix.os }}
+
+ - name: Install PyTorch ${{ matrix.pytorch-version }} with CUDA ${{ matrix.cuda-version }}
+ run: |
+ bash -x .github/workflows/scripts/pytorch-install.sh ${{ matrix.python-version }} ${{ matrix.pytorch-version }} ${{ matrix.cuda-version }}
+
+ - name: Build wheel
+ shell: bash
+ run: |
+ bash -x .github/workflows/scripts/build.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }}
+ wheel_name=$(ls dist/*whl | xargs -n 1 basename)
+ asset_name=${wheel_name//"linux"/"manylinux1"}
+ echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
+ echo "asset_name=${asset_name}" >> $GITHUB_ENV
+
+ - name: Upload Release Asset
+ uses: actions/upload-release-asset@v1
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ with:
+ upload_url: ${{ needs.release.outputs.upload_url }}
+ asset_path: ./dist/${{ env.wheel_name }}
+ asset_name: ${{ env.asset_name }}
+ asset_content_type: application/*
+
+ # (Danielkinz): This last step will publish the .whl to pypi. Warning: untested
+ # - name: Publish package
+ # uses: pypa/gh-action-pypi-publish@release/v1.8
+ # with:
+ # repository-url: https://test.pypi.org/legacy/
+ # password: ${{ secrets.PYPI_API_TOKEN }}
+ # skip-existing: true
diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml
new file mode 100644
index 0000000000000000000000000000000000000000..bd38d11872dc4789899f885bc82cadfc65261534
--- /dev/null
+++ b/.github/workflows/ruff.yml
@@ -0,0 +1,31 @@
+name: ruff
+
+on:
+ # Trigger the workflow on push or pull request,
+ # but only for the main branch
+ push:
+ branches:
+ - main
+ pull_request:
+ branches:
+ - main
+
+jobs:
+ ruff:
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ python-version: ["3.10"]
+ steps:
+ - uses: actions/checkout@v2
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install ruff==0.1.5
+ - name: Analysing the code with ruff
+ run: |
+ ruff vllm tests
diff --git a/.github/workflows/scripts/build.sh b/.github/workflows/scripts/build.sh
new file mode 100644
index 0000000000000000000000000000000000000000..2578d448436d2d922bc179ac8fa238ff94a93156
--- /dev/null
+++ b/.github/workflows/scripts/build.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
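+# Usage (as invoked from .github/workflows/publish.yml):
+#   bash .github/workflows/scripts/build.sh <python-version> <cuda-version>
+# e.g. bash .github/workflows/scripts/build.sh 3.10 12.1
+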
+python_executable=python$1
+cuda_home=/usr/local/cuda-$2
+
+# Update paths
+export PATH=${cuda_home}/bin:$PATH
+export LD_LIBRARY_PATH=${cuda_home}/lib64:$LD_LIBRARY_PATH
+
+# Install requirements
+$python_executable -m pip install wheel packaging
+$python_executable -m pip install -r requirements.txt
+
+# Limit the number of parallel jobs to avoid OOM
+export MAX_JOBS=1
+# Make sure punica is built for the release (for LoRA)
+export VLLM_INSTALL_PUNICA_KERNELS=1
+
+# Build
+$python_executable setup.py bdist_wheel --dist-dir=dist
diff --git a/.github/workflows/scripts/create_release.js b/.github/workflows/scripts/create_release.js
new file mode 100644
index 0000000000000000000000000000000000000000..0f25624b4c21c6ba84c66809d5e87a18a29c44c2
--- /dev/null
+++ b/.github/workflows/scripts/create_release.js
@@ -0,0 +1,20 @@
+// Uses GitHub's API to create the release and wait for the result.
+// We use a JS script because the GitHub CLI returns immediately and does not provide a way to wait for the release's creation.
+
+module.exports = async (github, context, core) => {
+ try {
+ const response = await github.rest.repos.createRelease({
+ draft: false,
+ generate_release_notes: true,
+ name: process.env.RELEASE_TAG,
+ owner: context.repo.owner,
+ prerelease: false,
+ repo: context.repo.repo,
+ tag_name: process.env.RELEASE_TAG,
+ });
+
+ core.setOutput('upload_url', response.data.upload_url);
+ } catch (error) {
+ core.setFailed(error.message);
+ }
+}
\ No newline at end of file
diff --git a/.github/workflows/scripts/cuda-install.sh b/.github/workflows/scripts/cuda-install.sh
new file mode 100644
index 0000000000000000000000000000000000000000..312c6e82f33a3231f3e6516b3a9ca210d004ef92
--- /dev/null
+++ b/.github/workflows/scripts/cuda-install.sh
@@ -0,0 +1,23 @@
+#!/bin/bash
+
+# Replace '.' with '-' ex: 11.8 -> 11-8
+cuda_version=$(echo $1 | tr "." "-")
+# Removes '-' and '.' ex: ubuntu-20.04 -> ubuntu2004
+OS=$(echo $2 | tr -d ".\-")
+
+# Installs CUDA
+wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-keyring_1.1-1_all.deb
+sudo dpkg -i cuda-keyring_1.1-1_all.deb
+rm cuda-keyring_1.1-1_all.deb
+sudo apt -qq update
+sudo apt -y install cuda-${cuda_version} cuda-nvcc-${cuda_version} cuda-libraries-dev-${cuda_version}
+sudo apt clean
+
+# Test nvcc
+PATH=/usr/local/cuda-$1/bin:${PATH}
+nvcc --version
+
+# Log gcc, g++, c++ versions
+gcc --version
+g++ --version
+c++ --version
diff --git a/.github/workflows/scripts/env.sh b/.github/workflows/scripts/env.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d7baaecbbc7544bdac802d832aefaec357863125
--- /dev/null
+++ b/.github/workflows/scripts/env.sh
@@ -0,0 +1,56 @@
+#!/bin/bash
+
+# This file installs common linux environment tools
+
+export LANG=C.UTF-8
+
+# python_version=$1
+
+sudo apt-get update && \
+sudo apt-get install -y --no-install-recommends \
+    software-properties-common
+
+sudo apt-get install -y --no-install-recommends \
+ build-essential \
+ apt-utils \
+ ca-certificates \
+ wget \
+ git \
+ vim \
+ libssl-dev \
+ curl \
+ unzip \
+ unrar \
+ cmake \
+ net-tools \
+ sudo \
+ autotools-dev \
+ rsync \
+ jq \
+ openssh-server \
+ tmux \
+ screen \
+ htop \
+ pdsh \
+ openssh-client \
+ lshw \
+ dmidecode \
+ util-linux \
+ automake \
+ autoconf \
+ libtool \
+ net-tools \
+ pciutils \
+ libpci-dev \
+ libaio-dev \
+ libcap2 \
+ libtinfo5 \
+ fakeroot \
+ devscripts \
+ debhelper \
+ nfs-common
+
+# Remove github bloat files to free up disk space
+sudo rm -rf "/usr/local/share/boost"
+sudo rm -rf "$AGENT_TOOLSDIRECTORY"
+sudo rm -rf "/usr/share/dotnet"
diff --git a/.github/workflows/scripts/pytorch-install.sh b/.github/workflows/scripts/pytorch-install.sh
new file mode 100644
index 0000000000000000000000000000000000000000..dfc1851d7692cd6e2919b534e7c5eac441c732b9
--- /dev/null
+++ b/.github/workflows/scripts/pytorch-install.sh
@@ -0,0 +1,15 @@
+#!/bin/bash
+
+python_executable=python$1
+pytorch_version=$2
+cuda_version=$3
+
+# Install torch
+$python_executable -m pip install numpy pyyaml scipy ipython mkl mkl-include ninja cython typing pandas typing-extensions dataclasses setuptools && conda clean -ya
+$python_executable -m pip install torch==${pytorch_version}+cu${cuda_version//./} --extra-index-url https://download.pytorch.org/whl/cu${cuda_version//./}
+
+# Print version information
+$python_executable --version
+$python_executable -c "import torch; print('PyTorch:', torch.__version__)"
+$python_executable -c "import torch; print('CUDA:', torch.version.cuda)"
+$python_executable -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)"
diff --git a/.github/workflows/yapf.yml b/.github/workflows/yapf.yml
new file mode 100644
index 0000000000000000000000000000000000000000..b163c960db555b5edb1d1a287f5a764becb762cf
--- /dev/null
+++ b/.github/workflows/yapf.yml
@@ -0,0 +1,31 @@
+name: yapf
+
+on:
+ # Trigger the workflow on push or pull request,
+ # but only for the main branch
+ push:
+ branches:
+ - main
+ pull_request:
+ branches:
+ - main
+jobs:
+ yapf:
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ python-version: ["3.10"]
+ steps:
+ - uses: actions/checkout@v2
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install yapf==0.32.0
+ pip install toml==0.10.2
+ - name: Running yapf
+ run: |
+ yapf --diff --recursive .
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..b5195629e5cf3727397c46008a67dde6a9f9d00a
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,186 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+.idea/
+
+# VSCode
+.vscode/
+
+# DS Store
+.DS_Store
+
+# Results
+*.csv
+
+# Python pickle files
+*.pkl
+
+# Sphinx documentation
+_build/
+
+# vim swap files
+*.swo
+*.swp
+
+# hip files generated by PyTorch
+*.hip
+*_hip*
+
+# Benchmark dataset
+*.json
diff --git a/.readthedocs.yaml b/.readthedocs.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..428e19908858989f35c0cb8ae1986b5114b25f7c
--- /dev/null
+++ b/.readthedocs.yaml
@@ -0,0 +1,21 @@
+# Read the Docs configuration file
+# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
+
+version: 2
+
+build:
+ os: ubuntu-22.04
+ tools:
+ python: "3.8"
+
+sphinx:
+ configuration: docs/source/conf.py
+
+# If using Sphinx, optionally build your docs in additional formats such as PDF
+formats:
+ - pdf
+
+# Optionally declare the Python requirements required to build your docs
+python:
+ install:
+ - requirements: docs/requirements-docs.txt
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000000000000000000000000000000000000..93a4de73faa89478c0968434313e03cbfe950032
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,77 @@
+# Contributing to vLLM
+
+Thank you for your interest in contributing to vLLM!
+Our community is open to everyone and welcomes all kinds of contributions, no matter how small or large.
+There are several ways you can contribute to the project:
+
+- Identify and report any issues or bugs.
+- Request or add a new model.
+- Suggest or implement new features.
+
+However, remember that contributions aren't just about code.
+We believe in the power of community support; thus, answering queries, assisting others, and enhancing the documentation are highly regarded and beneficial contributions.
+
+Finally, one of the most impactful ways to support us is by raising awareness about vLLM.
+Talk about it in your blog posts, highlighting how it's driving your incredible projects.
+Express your support on Twitter if vLLM aids you, or simply offer your appreciation by starring our repository.
+
+
+## Setup for development
+
+### Build from source
+
+```bash
+pip install -r requirements.txt
+pip install -e . # This may take several minutes.
+```
+
+### Testing
+
+```bash
+pip install -r requirements-dev.txt
+
+# Static type checking
+mypy
+# Unit tests
+pytest tests/
+```
+**Note:** Currently, the repository does not pass the mypy tests.
+
+
+## Contributing Guidelines
+
+### Issue Reporting
+
+If you encounter a bug or have a feature request, please check our issues page first to see if someone else has already reported it.
+If not, please file a new issue, providing as much relevant information as possible.
+
+### Coding Style Guide
+
+In general, we adhere to [Google Python style guide](https://google.github.io/styleguide/pyguide.html) and [Google C++ style guide](https://google.github.io/styleguide/cppguide.html).
+
+We include a formatting script [`format.sh`](./format.sh) to format the code.
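+
+A typical pre-commit invocation (a minimal sketch, assuming the dev requirements
+above are installed and the script is run from the repository root):
+
+```bash
+# Format the code in place before opening a pull request.
+bash format.sh
+```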
+
+### Pull Requests
+
+When submitting a pull request:
+
+1. Make sure your code has been rebased on top of the latest commit on the main branch.
+2. Ensure code is properly formatted by running [`format.sh`](./format.sh).
+3. Include a detailed description of the changes in the pull request.
+Explain why you made the changes you did.
+If your pull request fixes an open issue, please include a reference to it in the description.
+
+### Code Reviews
+
+All submissions, including submissions by project members, require a code review.
+To make the review process as smooth as possible, please:
+
+1. Keep your changes as concise as possible.
+If your pull request involves multiple unrelated changes, consider splitting it into separate pull requests.
+2. Respond to all comments within a reasonable time frame.
+If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.
+
+### Thank You
+
+Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM.
+Your contributions make vLLM a great tool for everyone!
diff --git a/Dockerfile b/Dockerfile
index 8a3e25f736d36743219eb39a8dd6e0893fd6fc3c..4cfcf058004c57e291e1aa8de31fa7f32840f16c 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -94,4 +94,4 @@ COPY --from=build /workspace/vllm/*.so /workspace/vllm/
COPY vllm vllm
ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
-#################### OPENAI API SERVER ####################
\ No newline at end of file
+#################### OPENAI API SERVER ####################
diff --git a/Dockerfile.rocm b/Dockerfile.rocm
new file mode 100644
index 0000000000000000000000000000000000000000..88172fb73b937828a111492997ad96b66cad403c
--- /dev/null
+++ b/Dockerfile.rocm
@@ -0,0 +1,88 @@
+# default base image
+ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
+
+FROM $BASE_IMAGE
+
+ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
+
+RUN echo "Base image is $BASE_IMAGE"
+
+# BASE_IMAGE for ROCm_5.7: "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1"
+# BASE_IMAGE for ROCm_6.0: "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
+
+# this does not always work for all rocm versions
+RUN LLVM_GFX_ARCH=$(/opt/rocm/llvm/bin/amdgpu-offload-arch) && \
+ echo "LLVM_GFX_ARCH is $LLVM_GFX_ARCH"
+
+ARG FA_GFX_ARCHS="gfx90a;gfx942"
+RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"
+
+ARG FA_BRANCH="3d2b6f5"
+RUN echo "FA_BRANCH is $FA_BRANCH"
+
+# Install Python 3 and pip
+RUN apt-get update && apt-get install python3 python3-pip -y
+
+# Install some basic utilities
+RUN apt-get update && apt-get install -y \
+ curl \
+ ca-certificates \
+ sudo \
+ git \
+ bzip2 \
+ libx11-6 \
+ build-essential \
+ wget \
+ unzip \
+ nvidia-cuda-toolkit \
+ tmux \
+ && rm -rf /var/lib/apt/lists/*
+
+### Mount Point ###
+# When launching the container, mount the code directory to /app
+ARG APP_MOUNT=/app
+VOLUME [ ${APP_MOUNT} ]
+WORKDIR ${APP_MOUNT}
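+
+# For illustration only, a typical launch that passes through the ROCm devices and
+# mounts the code directory would look like:
+#   docker run -it --device=/dev/kfd --device=/dev/dri -v "$(pwd)":/app <image>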
+
+RUN python3 -m pip install --upgrade pip
+RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas
+
+ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
+ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin:
+ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
+ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:
+
+# Install ROCm flash-attention
+RUN mkdir libs \
+ && cd libs \
+ && git clone https://github.com/ROCmSoftwarePlatform/flash-attention.git \
+ && cd flash-attention \
+ && git checkout ${FA_BRANCH} \
+ && git submodule update --init \
+ && export GPU_ARCHS=${FA_GFX_ARCHS} \
+ && if [ "$BASE_IMAGE" = "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" ]; then \
+ patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch; fi \
+ && python3 setup.py install \
+ && cd ..
+
+COPY ./ /app/vllm
+
+RUN python3 -m pip install --upgrade pip
+RUN python3 -m pip install xformers==0.0.23 --no-deps
+
+# The base image ships numpy 1.20.3 in a broken state (no METADATA file, only a stray
+# LICENSES_bundled.txt). Remove it manually so that the later numpy upgrade can proceed.
+RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \
+ rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi
+
+RUN cd /app \
+ && cd vllm \
+ && pip install -U -r requirements-rocm.txt \
+ && bash patch_xformers.rocm.sh \
+ && python3 setup.py install \
+ && cd ..
+
+RUN python3 -m pip install --upgrade pip
+RUN python3 -m pip install --no-cache-dir ray[all]
+
+CMD ["/bin/bash"]
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 0000000000000000000000000000000000000000..0c897cf147f109d6a452905acfd006934fa495dc
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1,4 @@
+include LICENSE
+include requirements.txt
+
+recursive-include csrc *
diff --git a/README.md b/README.md
index 5bc0cb80eab8ff973042a32284ee5c135cfff7b9..c0d267b2cbbf3d3488ec244f7d265f32a761c503 100644
--- a/README.md
+++ b/README.md
@@ -1,10 +1,113 @@
+Easy, fast, and cheap LLM serving for everyone
+
+| Documentation | Blog | Paper | Discord |
+
+---
+
+**The Second vLLM Bay Area Meetup (Jan 31st 5pm-7:30pm PT)**
+
+We are thrilled to announce our second vLLM Meetup!
+The vLLM team will share recent updates and roadmap.
+We will also have vLLM collaborators from IBM coming up to the stage to discuss their insights on LLM optimizations.
+Please register [here](https://lu.ma/ygxbpzhl) and join us!
+
---
-title: Certifaier
-emoji: 🏆
-colorFrom: indigo
-colorTo: red
-sdk: docker
-pinned: false
+
+*Latest News* 🔥
+- [2024/01] Added ROCm 6.0 support to vLLM.
+- [2023/12] Added ROCm 5.7 support to vLLM.
+- [2023/10] We hosted [the first vLLM meetup](https://lu.ma/first-vllm-meetup) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?usp=sharing).
+- [2023/09] We created our [Discord server](https://discord.gg/jz7wjKhh6g)! Join us to discuss vLLM and LLM serving! We will also post the latest announcements and updates there.
+- [2023/09] We released our [PagedAttention paper](https://arxiv.org/abs/2309.06180) on arXiv!
+- [2023/08] We would like to express our sincere gratitude to [Andreessen Horowitz](https://a16z.com/2023/08/30/supporting-the-open-source-ai-community/) (a16z) for providing a generous grant to support the open-source development and research of vLLM.
+- [2023/07] Added support for LLaMA-2! You can run and serve 7B/13B/70B LLaMA-2s on vLLM with a single command!
+- [2023/06] Serving vLLM On any Cloud with SkyPilot. Check out a 1-click [example](https://github.com/skypilot-org/skypilot/blob/master/llm/vllm) to start the vLLM demo, and the [blog post](https://blog.skypilot.co/serving-llm-24x-faster-on-the-cloud-with-vllm-and-skypilot/) for the story behind vLLM development on the clouds.
+- [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai).
+
---
+## About
+vLLM is a fast and easy-to-use library for LLM inference and serving.
+
+vLLM is fast with:
+
+- State-of-the-art serving throughput
+- Efficient management of attention key and value memory with **PagedAttention**
+- Continuous batching of incoming requests
+- Fast model execution with CUDA/HIP graph
+- Quantization: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [SqueezeLLM](https://arxiv.org/abs/2306.07629)
+- Optimized CUDA kernels
+
+vLLM is flexible and easy to use with:
+
+- Seamless integration with popular Hugging Face models
+- High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more
+- Tensor parallelism support for distributed inference
+- Streaming outputs
+- OpenAI-compatible API server
+- Support for NVIDIA GPUs and AMD GPUs
+
+vLLM seamlessly supports many Hugging Face models, including the following architectures:
+
+- Aquila & Aquila2 (`BAAI/AquilaChat2-7B`, `BAAI/AquilaChat2-34B`, `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.)
+- Baichuan & Baichuan2 (`baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.)
+- BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
+- ChatGLM (`THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, etc.)
+- DeciLM (`Deci/DeciLM-7B`, `Deci/DeciLM-7B-instruct`, etc.)
+- Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.)
+- GPT-2 (`gpt2`, `gpt2-xl`, etc.)
+- GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
+- GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.)
+- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
+- InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.)
+- LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
+- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
+- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.)
+- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
+- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
+- Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.)
+- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
+- Qwen2 (`Qwen/Qwen2-7B-beta`, `Qwen/Qwen-7B-Chat-beta`, etc.)
+- StableLM (`stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc.)
+- Yi (`01-ai/Yi-6B`, `01-ai/Yi-34B`, etc.)
+
+Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):
+
+```bash
+pip install vllm
+```
+
+## Getting Started
+
+Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to get started.
+- [Installation](https://vllm.readthedocs.io/en/latest/getting_started/installation.html)
+- [Quickstart](https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html)
+- [Supported Models](https://vllm.readthedocs.io/en/latest/models/supported_models.html)
+
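+As a minimal sketch, you can start the OpenAI-compatible server and check that it is
+up (the model below mirrors the one used by `.buildkite/run-benchmarks.sh`; any
+supported model works):
+
+```bash
+# Launch the server in the background, then list the served models.
+python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-chat-hf &
+curl localhost:8000/v1/models
+```
+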
+## Contributing
+
+We welcome and value any contributions and collaborations.
+Please check out [CONTRIBUTING.md](./CONTRIBUTING.md) for how to get involved.
+
+## Citation
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs/2309.06180):
+```bibtex
+@inproceedings{kwon2023efficient,
+ title={Efficient Memory Management for Large Language Model Serving with PagedAttention},
+ author={Woosuk Kwon and Zhuohan Li and Siyuan Zhuang and Ying Sheng and Lianmin Zheng and Cody Hao Yu and Joseph E. Gonzalez and Hao Zhang and Ion Stoica},
+ booktitle={Proceedings of the ACM SIGOPS 29th Symposium on Operating Systems Principles},
+ year={2023}
+}
+```
diff --git a/benchmarks/README.md b/benchmarks/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..192d6c4022c839f4d8c459d572f8fa23b6ac0968
--- /dev/null
+++ b/benchmarks/README.md
@@ -0,0 +1,8 @@
+# Benchmarking vLLM
+
+## Downloading the ShareGPT dataset
+
+You can download the dataset by running:
+```bash
+wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
+```
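+
+Once downloaded, the dataset can be fed to the serving benchmark in this directory.
+A minimal client-side invocation, mirroring `.buildkite/run-benchmarks.sh` (it assumes
+an OpenAI-compatible API server for the same model is already running on
+`localhost:8000`):
+
+```bash
+python benchmarks/benchmark_serving.py \
+    --dataset ./ShareGPT_V3_unfiltered_cleaned_split.json \
+    --model meta-llama/Llama-2-7b-chat-hf \
+    --tokenizer meta-llama/Llama-2-7b-chat-hf \
+    --endpoint /v1/completions \
+    --num-prompts 20
+```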
diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py
new file mode 100644
index 0000000000000000000000000000000000000000..7173134358762999dd60303842018698ef554b5a
--- /dev/null
+++ b/benchmarks/benchmark_latency.py
@@ -0,0 +1,139 @@
+"""Benchmark the latency of processing a single batch of requests."""
+import argparse
+import time
+from pathlib import Path
+from typing import Optional
+
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from vllm import LLM, SamplingParams
+
+
+def main(args: argparse.Namespace):
+ print(args)
+
+ # NOTE(woosuk): If the request cannot be processed in a single batch,
+ # the engine will automatically process the request in multiple batches.
+ llm = LLM(
+ model=args.model,
+ tokenizer=args.tokenizer,
+ quantization=args.quantization,
+ tensor_parallel_size=args.tensor_parallel_size,
+ trust_remote_code=args.trust_remote_code,
+ dtype=args.dtype,
+ enforce_eager=args.enforce_eager,
+ kv_cache_dtype=args.kv_cache_dtype,
+ )
+
+ sampling_params = SamplingParams(
+ n=args.n,
+ temperature=0.0 if args.use_beam_search else 1.0,
+ top_p=1.0,
+ use_beam_search=args.use_beam_search,
+ ignore_eos=True,
+ max_tokens=args.output_len,
+ )
+ print(sampling_params)
+ dummy_prompt_token_ids = [[0] * args.input_len] * args.batch_size
+
+ def run_to_completion(profile_dir: Optional[str] = None):
+ if profile_dir:
+ with torch.profiler.profile(
+ activities=[
+ torch.profiler.ProfilerActivity.CPU,
+ torch.profiler.ProfilerActivity.CUDA,
+ ],
+ on_trace_ready=torch.profiler.tensorboard_trace_handler(
+ str(profile_dir))) as p:
+ llm.generate(prompt_token_ids=dummy_prompt_token_ids,
+ sampling_params=sampling_params,
+ use_tqdm=False)
+ print(p.key_averages())
+ else:
+ start_time = time.perf_counter()
+ llm.generate(prompt_token_ids=dummy_prompt_token_ids,
+ sampling_params=sampling_params,
+ use_tqdm=False)
+ end_time = time.perf_counter()
+ latency = end_time - start_time
+ return latency
+
+ print("Warming up...")
+ run_to_completion(profile_dir=None)
+
+ if args.profile:
+ profile_dir = args.profile_result_dir
+ if not profile_dir:
+ profile_dir = Path(
+ "."
+ ) / "vllm_benchmark_result" / f"latency_result_{time.time()}"
+ print(f"Profiling (results will be saved to '{profile_dir}')...")
+        run_to_completion(profile_dir=profile_dir)
+ return
+
+ # Benchmark.
+ latencies = []
+ for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
+ latencies.append(run_to_completion(profile_dir=None))
+ print(f'Avg latency: {np.mean(latencies)} seconds')
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(
+ description='Benchmark the latency of processing a single batch of '
+ 'requests till completion.')
+ parser.add_argument('--model', type=str, default='facebook/opt-125m')
+ parser.add_argument('--tokenizer', type=str, default=None)
+ parser.add_argument('--quantization',
+ '-q',
+ choices=['awq', 'gptq', 'squeezellm', None],
+ default=None)
+ parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
+ parser.add_argument('--input-len', type=int, default=32)
+ parser.add_argument('--output-len', type=int, default=128)
+ parser.add_argument('--batch-size', type=int, default=8)
+ parser.add_argument('--n',
+ type=int,
+ default=1,
+ help='Number of generated sequences per prompt.')
+ parser.add_argument('--use-beam-search', action='store_true')
+ parser.add_argument('--num-iters',
+ type=int,
+ default=3,
+ help='Number of iterations to run.')
+ parser.add_argument('--trust-remote-code',
+ action='store_true',
+ help='trust remote code from huggingface')
+ parser.add_argument(
+ '--dtype',
+ type=str,
+ default='auto',
+ choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
+ help='data type for model weights and activations. '
+ 'The "auto" option will use FP16 precision '
+ 'for FP32 and FP16 models, and BF16 precision '
+ 'for BF16 models.')
+ parser.add_argument('--enforce-eager',
+ action='store_true',
+ help='enforce eager mode and disable CUDA graph')
+ parser.add_argument(
+ "--kv-cache-dtype",
+ type=str,
+ choices=['auto', 'fp8_e5m2'],
+ default='auto',
+ help=
+ 'Data type for kv cache storage. If "auto", will use model data type.')
+ parser.add_argument(
+ '--profile',
+ action='store_true',
+ help='profile the generation process of a single batch')
+ parser.add_argument(
+ '--profile-result-dir',
+ type=str,
+ default=None,
+ help=('path to save the pytorch profiler output. Can be visualized '
+ 'with ui.perfetto.dev or Tensorboard.'))
+ args = parser.parse_args()
+ main(args)
diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a36d9d6a5debc18e75214c1ee4fd47d457f1a3e
--- /dev/null
+++ b/benchmarks/benchmark_serving.py
@@ -0,0 +1,249 @@
+"""Benchmark online serving throughput.
+
+On the server side, run one of the following commands:
+ (vLLM backend)
+ python -m vllm.entrypoints.api_server \
+        --model <your_model> --swap-space 16 \
+ --disable-log-requests
+
+ (TGI backend)
+ ./launch_hf_server.sh
+
+On the client side, run:
+ python benchmarks/benchmark_serving.py \
+        --backend <backend> \
+        --tokenizer <your_model> --dataset <target_dataset> \
+        --request-rate <request_rate>
+"""
+import argparse
+import asyncio
+import json
+import random
+import time
+from typing import AsyncGenerator, List, Tuple
+
+import aiohttp
+import numpy as np
+from tqdm.asyncio import tqdm
+from transformers import PreTrainedTokenizerBase
+from vllm.transformers_utils.tokenizer import get_tokenizer
+
+# (prompt len, output len, latency)
+REQUEST_LATENCY: List[Tuple[int, int, float]] = []
+
+
+def sample_requests(
+ dataset_path: str,
+ num_requests: int,
+ tokenizer: PreTrainedTokenizerBase,
+) -> List[Tuple[str, int, int]]:
+ # Load the dataset.
+ with open(dataset_path) as f:
+ dataset = json.load(f)
+ # Filter out the conversations with less than 2 turns.
+ dataset = [data for data in dataset if len(data["conversations"]) >= 2]
+ # Only keep the first two turns of each conversation.
+ dataset = [(data["conversations"][0]["value"],
+ data["conversations"][1]["value"]) for data in dataset]
+
+ # Tokenize the prompts and completions.
+ prompts = [prompt for prompt, _ in dataset]
+ prompt_token_ids = tokenizer(prompts).input_ids
+ completions = [completion for _, completion in dataset]
+ completion_token_ids = tokenizer(completions).input_ids
+ tokenized_dataset = []
+ for i in range(len(dataset)):
+ output_len = len(completion_token_ids[i])
+ tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
+
+ # Filter out too long sequences.
+ filtered_dataset: List[Tuple[str, int, int]] = []
+ for prompt, prompt_token_ids, output_len in tokenized_dataset:
+ prompt_len = len(prompt_token_ids)
+ if prompt_len < 4 or output_len < 4:
+ # Prune too short sequences.
+ # This is because TGI causes errors when the input or output length
+ # is too short.
+ continue
+ if prompt_len > 1024 or prompt_len + output_len > 2048:
+ # Prune too long sequences.
+ continue
+ filtered_dataset.append((prompt, prompt_len, output_len))
+
+ # Sample the requests.
+ sampled_requests = random.sample(filtered_dataset, num_requests)
+ return sampled_requests
+
+
+async def get_request(
+ input_requests: List[Tuple[str, int, int]],
+ request_rate: float,
+) -> AsyncGenerator[Tuple[str, int, int], None]:
+ input_requests = iter(input_requests)
+ for request in input_requests:
+ yield request
+
+ if request_rate == float("inf"):
+ # If the request rate is infinity, then we don't need to wait.
+ continue
+ # Sample the request interval from the exponential distribution.
+ interval = np.random.exponential(1.0 / request_rate)
+ # The next request will be sent after the interval.
+ await asyncio.sleep(interval)
+
+
+async def send_request(backend: str, model: str, api_url: str, prompt: str,
+ prompt_len: int, output_len: int, best_of: int,
+ use_beam_search: bool, pbar: tqdm) -> None:
+ request_start_time = time.perf_counter()
+
+ headers = {"User-Agent": "Benchmark Client"}
+ if backend == "vllm":
+ pload = {
+ "prompt": prompt,
+ "n": 1,
+ "best_of": best_of,
+ "use_beam_search": use_beam_search,
+ "temperature": 0.0 if use_beam_search else 1.0,
+ "top_p": 1.0,
+ "max_tokens": output_len,
+ "ignore_eos": True,
+ "stream": False,
+ }
+ if model is not None:
+ pload["model"] = model
+ elif backend == "tgi":
+ assert not use_beam_search
+ params = {
+ "best_of": best_of,
+ "max_new_tokens": output_len,
+ "do_sample": True,
+ }
+ pload = {
+ "inputs": prompt,
+ "parameters": params,
+ }
+ else:
+ raise ValueError(f"Unknown backend: {backend}")
+
+ timeout = aiohttp.ClientTimeout(total=3 * 3600)
+ async with aiohttp.ClientSession(timeout=timeout) as session:
+ while True:
+ async with session.post(api_url, headers=headers,
+ json=pload) as response:
+ chunks = []
+ async for chunk, _ in response.content.iter_chunks():
+ chunks.append(chunk)
+ output = b"".join(chunks).decode("utf-8")
+ output = json.loads(output)
+
+ # Re-send the request if it failed.
+ if "error" not in output:
+ break
+
+ request_end_time = time.perf_counter()
+ request_latency = request_end_time - request_start_time
+ REQUEST_LATENCY.append((prompt_len, output_len, request_latency))
+ pbar.update(1)
+
+
+async def benchmark(
+ backend: str,
+ model: str,
+ api_url: str,
+ input_requests: List[Tuple[str, int, int]],
+ best_of: int,
+ use_beam_search: bool,
+ request_rate: float,
+) -> None:
+ tasks: List[asyncio.Task] = []
+ pbar = tqdm(total=len(input_requests))
+ async for request in get_request(input_requests, request_rate):
+ prompt, prompt_len, output_len = request
+ task = asyncio.create_task(
+ send_request(backend, model, api_url, prompt, prompt_len,
+ output_len, best_of, use_beam_search, pbar))
+ tasks.append(task)
+ await asyncio.gather(*tasks)
+ pbar.close()
+
+
+def main(args: argparse.Namespace):
+ print(args)
+ random.seed(args.seed)
+ np.random.seed(args.seed)
+
+ api_url = f"{args.protocol}://{args.host}:{args.port}{args.endpoint}"
+ tokenizer = get_tokenizer(args.tokenizer,
+ trust_remote_code=args.trust_remote_code)
+ input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
+
+ benchmark_start_time = time.perf_counter()
+ asyncio.run(
+ benchmark(args.backend, args.model, api_url, input_requests,
+ args.best_of, args.use_beam_search, args.request_rate))
+ benchmark_end_time = time.perf_counter()
+ benchmark_time = benchmark_end_time - benchmark_start_time
+ print(f"Total time: {benchmark_time:.2f} s")
+ print(f"Throughput: {args.num_prompts / benchmark_time:.2f} requests/s")
+
+ # Compute the latency statistics.
+ avg_latency = np.mean([latency for _, _, latency in REQUEST_LATENCY])
+ print(f"Average latency: {avg_latency:.2f} s")
+ avg_per_token_latency = np.mean([
+ latency / (prompt_len + output_len)
+ for prompt_len, output_len, latency in REQUEST_LATENCY
+ ])
+ print(f"Average latency per token: {avg_per_token_latency:.2f} s")
+ avg_per_output_token_latency = np.mean(
+ [latency / output_len for _, output_len, latency in REQUEST_LATENCY])
+ print("Average latency per output token: "
+ f"{avg_per_output_token_latency:.2f} s")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description="Benchmark the online serving throughput.")
+ parser.add_argument("--backend",
+ type=str,
+ default="vllm",
+ choices=["vllm", "tgi"])
+ parser.add_argument("--protocol",
+ type=str,
+ default="http",
+ choices=["http", "https"])
+ parser.add_argument("--host", type=str, default="localhost")
+ parser.add_argument("--port", type=int, default=8000)
+ parser.add_argument("--endpoint", type=str, default="/generate")
+ parser.add_argument("--model", type=str, default=None)
+ parser.add_argument("--dataset",
+ type=str,
+ required=True,
+ help="Path to the dataset.")
+ parser.add_argument("--tokenizer",
+ type=str,
+ required=True,
+ help="Name or path of the tokenizer.")
+ parser.add_argument("--best-of",
+ type=int,
+ default=1,
+ help="Generates `best_of` sequences per prompt and "
+ "returns the best one.")
+ parser.add_argument("--use-beam-search", action="store_true")
+ parser.add_argument("--num-prompts",
+ type=int,
+ default=1000,
+ help="Number of prompts to process.")
+ parser.add_argument("--request-rate",
+ type=float,
+ default=float("inf"),
+ help="Number of requests per second. If this is inf, "
+ "then all the requests are sent at time 0. "
+ "Otherwise, we use Poisson process to synthesize "
+ "the request arrival times.")
+ parser.add_argument("--seed", type=int, default=0)
+ parser.add_argument('--trust-remote-code',
+ action='store_true',
+ help='trust remote code from huggingface')
+ args = parser.parse_args()
+ main(args)
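Note: the --request-rate flag above describes Poisson arrivals. As a minimal illustrative sketch (the helper name is hypothetical; the script's own get_request is the real implementation), the inter-arrival gaps can be drawn from an exponential distribution:

import asyncio
import numpy as np

async def poisson_arrivals(requests, request_rate: float):
    # Yield each request, then sleep for an exponentially distributed gap so
    # that arrivals form a Poisson process with the given rate (requests/s).
    for request in requests:
        yield request
        if request_rate == float("inf"):
            continue  # inf rate: all requests are sent at time 0
        await asyncio.sleep(np.random.exponential(1.0 / request_rate))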
diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py
new file mode 100644
index 0000000000000000000000000000000000000000..d45d33307c9124c365e73c30a62771ce5082d2e4
--- /dev/null
+++ b/benchmarks/benchmark_throughput.py
@@ -0,0 +1,328 @@
+"""Benchmark offline inference throughput."""
+import argparse
+import json
+import random
+import time
+from typing import List, Optional, Tuple
+
+import torch
+from transformers import (AutoModelForCausalLM, AutoTokenizer,
+ PreTrainedTokenizerBase)
+from tqdm import tqdm
+
+
+def sample_requests(
+ dataset_path: str,
+ num_requests: int,
+ tokenizer: PreTrainedTokenizerBase,
+ fixed_output_len: Optional[int],
+) -> List[Tuple[str, int, int]]:
+ if fixed_output_len is not None and fixed_output_len < 4:
+ raise ValueError("output_len too small")
+
+ # Load the dataset.
+ with open(dataset_path) as f:
+ dataset = json.load(f)
+ # Filter out the conversations with less than 2 turns.
+ dataset = [data for data in dataset if len(data["conversations"]) >= 2]
+ # Only keep the first two turns of each conversation.
+ dataset = [(data["conversations"][0]["value"],
+ data["conversations"][1]["value"]) for data in dataset]
+
+ # Tokenize the prompts and completions.
+ prompts = [prompt for prompt, _ in dataset]
+ prompt_token_ids = tokenizer(prompts).input_ids
+ completions = [completion for _, completion in dataset]
+ completion_token_ids = tokenizer(completions).input_ids
+ tokenized_dataset = []
+ for i in range(len(dataset)):
+ output_len = len(completion_token_ids[i])
+ if fixed_output_len is not None:
+ output_len = fixed_output_len
+ tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
+
+ # Filter out too long sequences.
+ filtered_dataset: List[Tuple[str, int, int]] = []
+ for prompt, prompt_token_ids, output_len in tokenized_dataset:
+ prompt_len = len(prompt_token_ids)
+ if prompt_len < 4 or output_len < 4:
+ # Prune too short sequences.
+ continue
+ if prompt_len > 1024 or prompt_len + output_len > 2048:
+ # Prune too long sequences.
+ continue
+ filtered_dataset.append((prompt, prompt_len, output_len))
+
+ # Sample the requests.
+ sampled_requests = random.sample(filtered_dataset, num_requests)
+ return sampled_requests
+
+
+def run_vllm(
+ requests: List[Tuple[str, int, int]],
+ model: str,
+ tokenizer: str,
+ quantization: Optional[str],
+ tensor_parallel_size: int,
+ seed: int,
+ n: int,
+ use_beam_search: bool,
+ trust_remote_code: bool,
+ dtype: str,
+ max_model_len: Optional[int],
+ enforce_eager: bool,
+ kv_cache_dtype: str,
+) -> float:
+ from vllm import LLM, SamplingParams
+ llm = LLM(
+ model=model,
+ tokenizer=tokenizer,
+ quantization=quantization,
+ tensor_parallel_size=tensor_parallel_size,
+ seed=seed,
+ trust_remote_code=trust_remote_code,
+ dtype=dtype,
+ max_model_len=max_model_len,
+ enforce_eager=enforce_eager,
+ kv_cache_dtype=kv_cache_dtype,
+ )
+
+ # Add the requests to the engine.
+ for prompt, _, output_len in requests:
+ sampling_params = SamplingParams(
+ n=n,
+ temperature=0.0 if use_beam_search else 1.0,
+ top_p=1.0,
+ use_beam_search=use_beam_search,
+ ignore_eos=True,
+ max_tokens=output_len,
+ )
+ # FIXME(woosuk): Do not use internal method.
+ llm._add_request(
+ prompt=prompt,
+ prompt_token_ids=None,
+ sampling_params=sampling_params,
+ )
+
+ start = time.perf_counter()
+ # FIXME(woosuk): Do not use internal method.
+ llm._run_engine(use_tqdm=True)
+ end = time.perf_counter()
+ return end - start
+
+
+def run_hf(
+ requests: List[Tuple[str, int, int]],
+ model: str,
+ tokenizer: PreTrainedTokenizerBase,
+ n: int,
+ use_beam_search: bool,
+ max_batch_size: int,
+ trust_remote_code: bool,
+) -> float:
+ assert not use_beam_search
+ llm = AutoModelForCausalLM.from_pretrained(
+ model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
+ if llm.config.model_type == "llama":
+ # To enable padding in the HF backend.
+ tokenizer.pad_token = tokenizer.eos_token
+ llm = llm.cuda()
+
+ pbar = tqdm(total=len(requests))
+ start = time.perf_counter()
+ batch: List[str] = []
+ max_prompt_len = 0
+ max_output_len = 0
+ for i in range(len(requests)):
+ prompt, prompt_len, output_len = requests[i]
+ # Add the prompt to the batch.
+ batch.append(prompt)
+ max_prompt_len = max(max_prompt_len, prompt_len)
+ max_output_len = max(max_output_len, output_len)
+ if len(batch) < max_batch_size and i != len(requests) - 1:
+ # Check if we can add more requests to the batch.
+ _, next_prompt_len, next_output_len = requests[i + 1]
+ if (max(max_prompt_len, next_prompt_len) +
+ max(max_output_len, next_output_len)) <= 2048:
+ # We can add more requests to the batch.
+ continue
+
+ # Generate the sequences.
+ input_ids = tokenizer(batch, return_tensors="pt",
+ padding=True).input_ids
+ llm_outputs = llm.generate(
+ input_ids=input_ids.cuda(),
+ do_sample=not use_beam_search,
+ num_return_sequences=n,
+ temperature=1.0,
+ top_p=1.0,
+ use_cache=True,
+ max_new_tokens=max_output_len,
+ )
+ # Include the decoding time.
+ tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
+ pbar.update(len(batch))
+
+ # Clear the batch.
+ batch = []
+ max_prompt_len = 0
+ max_output_len = 0
+ end = time.perf_counter()
+ return end - start
+
+
+def run_mii(
+ requests: List[Tuple[str, int, int]],
+ model: str,
+ tensor_parallel_size: int,
+ output_len: int,
+) -> float:
+ from mii import pipeline
+ llm = pipeline(model, tensor_parallel=tensor_parallel_size)
+ prompts = [prompt for prompt, _, _ in requests]
+
+ start = time.perf_counter()
+ llm(prompts, max_new_tokens=output_len)
+ end = time.perf_counter()
+ return end - start
+
+
+def main(args: argparse.Namespace):
+ print(args)
+ random.seed(args.seed)
+
+ # Sample the requests.
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.tokenizer, trust_remote_code=args.trust_remote_code)
+ if args.dataset is None:
+ # Synthesize a prompt with the given input length.
+ prompt = "hi" * (args.input_len - 1)
+ requests = [(prompt, args.input_len, args.output_len)
+ for _ in range(args.num_prompts)]
+ else:
+ requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
+ args.output_len)
+
+ if args.backend == "vllm":
+ elapsed_time = run_vllm(requests, args.model, args.tokenizer,
+ args.quantization, args.tensor_parallel_size,
+ args.seed, args.n, args.use_beam_search,
+ args.trust_remote_code, args.dtype,
+ args.max_model_len, args.enforce_eager,
+ args.kv_cache_dtype)
+ elif args.backend == "hf":
+ assert args.tensor_parallel_size == 1
+ elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
+ args.use_beam_search, args.hf_max_batch_size,
+ args.trust_remote_code)
+ elif args.backend == "mii":
+ elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
+ args.output_len)
+ else:
+ raise ValueError(f"Unknown backend: {args.backend}")
+ total_num_tokens = sum(prompt_len + output_len
+ for _, prompt_len, output_len in requests)
+ print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
+ f"{total_num_tokens / elapsed_time:.2f} tokens/s")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Benchmark the throughput.")
+ parser.add_argument("--backend",
+ type=str,
+ choices=["vllm", "hf", "mii"],
+ default="vllm")
+ parser.add_argument("--dataset",
+ type=str,
+ default=None,
+ help="Path to the dataset.")
+ parser.add_argument("--input-len",
+ type=int,
+ default=None,
+ help="Input prompt length for each request")
+ parser.add_argument("--output-len",
+ type=int,
+ default=None,
+ help="Output length for each request. Overrides the "
+ "output length from the dataset.")
+ parser.add_argument("--model", type=str, default="facebook/opt-125m")
+ parser.add_argument("--tokenizer", type=str, default=None)
+ parser.add_argument('--quantization',
+ '-q',
+ choices=['awq', 'gptq', 'squeezellm', None],
+ default=None)
+ parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
+ parser.add_argument("--n",
+ type=int,
+ default=1,
+ help="Number of generated sequences per prompt.")
+ parser.add_argument("--use-beam-search", action="store_true")
+ parser.add_argument("--num-prompts",
+ type=int,
+ default=1000,
+ help="Number of prompts to process.")
+ parser.add_argument("--seed", type=int, default=0)
+ parser.add_argument("--hf-max-batch-size",
+ type=int,
+ default=None,
+ help="Maximum batch size for HF backend.")
+ parser.add_argument('--trust-remote-code',
+ action='store_true',
+ help='trust remote code from huggingface')
+ parser.add_argument(
+ '--max-model-len',
+ type=int,
+ default=None,
+ help='Maximum length of a sequence (including prompt and output). '
+ 'If None, will be derived from the model.')
+ parser.add_argument(
+ '--dtype',
+ type=str,
+ default='auto',
+ choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
+ help='data type for model weights and activations. '
+ 'The "auto" option will use FP16 precision '
+ 'for FP32 and FP16 models, and BF16 precision '
+ 'for BF16 models.')
+ parser.add_argument("--enforce-eager",
+ action="store_true",
+ help="enforce eager execution")
+ parser.add_argument(
+ "--kv-cache-dtype",
+ type=str,
+ choices=["auto", "fp8_e5m2"],
+ default="auto",
+ help=
+ 'Data type for kv cache storage. If "auto", will use model data type.')
+ args = parser.parse_args()
+ if args.tokenizer is None:
+ args.tokenizer = args.model
+ if args.dataset is None:
+ assert args.input_len is not None
+ assert args.output_len is not None
+ else:
+ assert args.input_len is None
+
+ if args.backend == "vllm":
+ if args.hf_max_batch_size is not None:
+ raise ValueError("HF max batch size is only for HF backend.")
+ elif args.backend == "hf":
+ if args.hf_max_batch_size is None:
+ raise ValueError("HF max batch size is required for HF backend.")
+ if args.quantization is not None:
+ raise ValueError("Quantization is only for vLLM backend.")
+ elif args.backend == "mii":
+ if args.dtype != "auto":
+ raise ValueError("dtype must be auto for MII backend.")
+ if args.n != 1:
+ raise ValueError("n must be 1 for MII backend.")
+ if args.use_beam_search:
+ raise ValueError("Beam search is not supported for MII backend.")
+ if args.quantization is not None:
+ raise ValueError("Quantization is only for vLLM backend.")
+ if args.hf_max_batch_size is not None:
+ raise ValueError("HF max batch size is only for HF backend.")
+ if args.tokenizer != args.model:
+ raise ValueError("Tokenizer must be the same as the model for MII "
+ "backend.")
+ main(args)
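Note: a hypothetical way to drive this benchmark programmatically (the dataset path is a placeholder; any ShareGPT-style JSON works, and all flags shown exist in the parser above):

import subprocess
import sys

# Placeholder invocation of the offline throughput benchmark.
subprocess.run(
    [
        sys.executable, "benchmarks/benchmark_throughput.py",
        "--backend", "vllm",
        "--dataset", "ShareGPT_V3_unfiltered_cleaned_split.json",
        "--model", "facebook/opt-125m",
        "--num-prompts", "100",
    ],
    check=True,
)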
diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..56fe1b921d44ecee919d10ae49feb3725ce3448d
--- /dev/null
+++ b/benchmarks/kernels/benchmark_paged_attention.py
@@ -0,0 +1,196 @@
+from typing import Optional
+import argparse
+import random
+import time
+
+import torch
+
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
+from vllm._C import ops
+
+NUM_BLOCKS = 1024
+PARTITION_SIZE = 512
+
+
+@torch.inference_mode()
+def main(
+ version: str,
+ num_seqs: int,
+ context_len: int,
+ num_query_heads: int,
+ num_kv_heads: int,
+ head_size: int,
+ use_alibi: bool,
+ block_size: int,
+ dtype: torch.dtype,
+ seed: int,
+ do_profile: bool,
+ kv_cache_dtype: Optional[str] = None,
+) -> None:
+ random.seed(seed)
+ torch.random.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+
+ scale = float(1.0 / (head_size**0.5))
+ query = torch.empty(num_seqs,
+ num_query_heads,
+ head_size,
+ dtype=dtype,
+ device="cuda")
+ query.uniform_(-scale, scale)
+
+ assert num_query_heads % num_kv_heads == 0
+ alibi_slopes = None
+ if use_alibi:
+ alibi_slopes = torch.randn(num_query_heads,
+ dtype=torch.float,
+ device="cuda")
+
+ context_lens = [context_len for _ in range(num_seqs)]
+ max_context_len = max(context_lens)
+ context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
+
+ # Create the block tables.
+ max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
+ block_tables = []
+ for _ in range(num_seqs):
+ block_table = [
+ random.randint(0, NUM_BLOCKS - 1)
+ for _ in range(max_num_blocks_per_seq)
+ ]
+ block_tables.append(block_table)
+ block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")
+
+ # Create the KV cache.
+ key_caches, value_caches = create_kv_caches_with_random(
+ NUM_BLOCKS, block_size, 1, num_kv_heads, head_size, kv_cache_dtype,
+ dtype)
+ key_cache, value_cache = key_caches[0], value_caches[0]
+
+ # Prepare for the paged attention kernel.
+ output = torch.empty_like(query)
+ if version == "v2":
+ num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
+ PARTITION_SIZE)
+ tmp_output = torch.empty(
+ size=(num_seqs, num_query_heads, num_partitions, head_size),
+ dtype=output.dtype,
+ device=output.device,
+ )
+ exp_sums = torch.empty(
+ size=(num_seqs, num_query_heads, num_partitions),
+ dtype=torch.float32,
+ device=output.device,
+ )
+ max_logits = torch.empty_like(exp_sums)
+
+ def run_benchmark(num_iters: int, profile: bool = False) -> float:
+ torch.cuda.synchronize()
+ if profile:
+ torch.cuda.cudart().cudaProfilerStart()
+ start_time = time.perf_counter()
+
+ for _ in range(num_iters):
+ if version == "v1":
+ ops.paged_attention_v1(
+ output,
+ query,
+ key_cache,
+ value_cache,
+ num_kv_heads,
+ scale,
+ block_tables,
+ context_lens,
+ block_size,
+ max_context_len,
+ alibi_slopes,
+ kv_cache_dtype,
+ )
+ elif version == "v2":
+ ops.paged_attention_v2(
+ output,
+ exp_sums,
+ max_logits,
+ tmp_output,
+ query,
+ key_cache,
+ value_cache,
+ num_kv_heads,
+ scale,
+ block_tables,
+ context_lens,
+ block_size,
+ max_context_len,
+ alibi_slopes,
+ kv_cache_dtype,
+ )
+ else:
+ raise ValueError(f"Invalid version: {version}")
+ torch.cuda.synchronize()
+
+ end_time = time.perf_counter()
+ if profile:
+ torch.cuda.cudart().cudaProfilerStop()  # stop, not start: the timed region has ended
+ return (end_time - start_time) / num_iters
+
+ # Warmup.
+ print("Warming up...")
+ run_benchmark(num_iters=3, profile=False)
+
+ # Benchmark.
+ if do_profile:
+ latency = run_benchmark(num_iters=1, profile=True)
+ else:
+ latency = run_benchmark(num_iters=100, profile=False)
+ print(f"Kernel running time: {latency * 1000000:.3f} us")
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(
+ description="Benchmark the paged attention kernel.")
+ parser.add_argument("--version",
+ type=str,
+ choices=["v1", "v2"],
+ default="v2")
+ parser.add_argument("--batch-size", type=int, default=8)
+ parser.add_argument("--context-len", type=int, default=4096)
+ parser.add_argument("--num-query-heads", type=int, default=64)
+ parser.add_argument("--num-kv-heads", type=int, default=8)
+ parser.add_argument("--head-size",
+ type=int,
+ choices=[64, 80, 96, 112, 128, 256],
+ default=128)
+ parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
+ parser.add_argument("--use-alibi", action="store_true")
+ parser.add_argument("--dtype",
+ type=str,
+ choices=["half", "bfloat16", "float"],
+ default="half")
+ parser.add_argument("--seed", type=int, default=0)
+ parser.add_argument("--profile", action="store_true")
+ parser.add_argument(
+ "--kv-cache-dtype",
+ type=str,
+ choices=["auto", "fp8_e5m2"],
+ default="auto",
+ help=
+ 'Data type for kv cache storage. If "auto", will use model data type.')
+ args = parser.parse_args()
+ print(args)
+
+ if args.num_query_heads % args.num_kv_heads != 0:
+ raise ValueError("num_query_heads must be divisible by num_kv_heads")
+ main(
+ version=args.version,
+ num_seqs=args.batch_size,
+ context_len=args.context_len,
+ num_query_heads=args.num_query_heads,
+ num_kv_heads=args.num_kv_heads,
+ head_size=args.head_size,
+ block_size=args.block_size,
+ use_alibi=args.use_alibi,
+ dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
+ seed=args.seed,
+ do_profile=args.profile,
+ kv_cache_dtype=args.kv_cache_dtype,
+ )
diff --git a/benchmarks/launch_tgi_server.sh b/benchmarks/launch_tgi_server.sh
new file mode 100755
index 0000000000000000000000000000000000000000..bdb25b78d85b477877e61b4fb8f277714aa2a851
--- /dev/null
+++ b/benchmarks/launch_tgi_server.sh
@@ -0,0 +1,16 @@
+#!/bin/bash
+
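+# Usage: ./launch_tgi_server.sh <model-id> <max-batch-total-tokens>
+# $1 is passed to --model-id and $2 to --max-batch-total-tokens below.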
+PORT=8000
+MODEL=$1
+TOKENS=$2
+
+docker run --gpus all --shm-size 1g -p $PORT:80 \
+ -v $PWD/data:/data \
+ ghcr.io/huggingface/text-generation-inference:0.8 \
+ --model-id $MODEL \
+ --sharded false \
+ --max-input-length 1024 \
+ --max-total-tokens 2048 \
+ --max-best-of 5 \
+ --max-concurrent-requests 5000 \
+ --max-batch-total-tokens $TOKENS
diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu
new file mode 100644
index 0000000000000000000000000000000000000000..5ba9ab178d5a42a95c4ce622db14136de85b44ad
--- /dev/null
+++ b/csrc/activation_kernels.cu
@@ -0,0 +1,118 @@
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/extension.h>
+#include <c10/cuda/CUDAGuard.h>
+
+#include "cuda_compat.h"
+#include "dispatch_utils.h"
+
+namespace vllm {
+
+template<typename T>
+__device__ __forceinline__ T silu(const T& x) {
+ // x * sigmoid(x)
+ return (T) (((float) x) / (1.0f + expf((float) -x)));
+}
+
+template<typename scalar_t>
+__global__ void silu_and_mul_kernel(
+ scalar_t* __restrict__ out, // [..., d]
+ const scalar_t* __restrict__ input, // [..., 2, d]
+ const int d) {
+ const int64_t token_idx = blockIdx.x;
+ for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
+ const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
+ const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
+ out[token_idx * d + idx] = silu(x) * y;
+ }
+}
+
+} // namespace vllm
+
+void silu_and_mul(
+ torch::Tensor& out, // [..., d]
+ torch::Tensor& input) // [..., 2 * d]
+{
+ int64_t num_tokens = input.numel() / input.size(-1);
+ int d = input.size(-1) / 2;
+
+ dim3 grid(num_tokens);
+ dim3 block(std::min(d, 1024));
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+ VLLM_DISPATCH_FLOATING_TYPES(
+ input.scalar_type(),
+ "silu_and_mul_kernel",
+ [&] {
+ vllm::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
+ out.data_ptr<scalar_t>(),
+ input.data_ptr<scalar_t>(),
+ d);
+ });
+}
+
+namespace vllm {
+
+// Element-wise activation kernel template.
+template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
+__global__ void activation_kernel(
+ scalar_t* __restrict__ out, // [..., d]
+ const scalar_t* __restrict__ input, // [..., d]
+ const int d) {
+ const int64_t token_idx = blockIdx.x;
+ for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
+ const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
+ out[token_idx * d + idx] = ACT_FN(x);
+ }
+}
+
+} // namespace vllm
+
+// Launch element-wise activation kernel.
+#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
+ int d = input.size(-1); \
+ int64_t num_tokens = input.numel() / d; \
+ dim3 grid(num_tokens); \
+ dim3 block(std::min(d, 1024)); \
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
+ VLLM_DISPATCH_FLOATING_TYPES( \
+ input.scalar_type(), \
+ "activation_kernel", \
+ [&] { \
+ vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
+ out.data_ptr<scalar_t>(), \
+ input.data_ptr<scalar_t>(), \
+ d); \
+ });
+
+namespace vllm {
+
+template<typename T>
+__device__ __forceinline__ T gelu_new_kernel(const T& x) {
+ const float x3 = (float) (x * x * x);
+ const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3))));
+ return ((T) 0.5) * x * (((T) 1.0) + t);
+}
+
+template<typename T>
+__device__ __forceinline__ T gelu_fast_kernel(const T& x) {
+ const float f = (float) x;
+ const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x));
+ return ((T) 0.5) * x * (((T) 1.0) + t);
+}
+
+} // namespace vllm
+
+void gelu_new(
+ torch::Tensor& out, // [..., d]
+ torch::Tensor& input) // [..., d]
+{
+ LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
+}
+
+void gelu_fast(
+ torch::Tensor& out, // [..., d]
+ torch::Tensor& input) // [..., d]
+{
+ LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
+}
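Note: for readability, the element-wise math implemented by these kernels can be written in a few lines of PyTorch; this is an illustrative reference following the formulas above, not the binding vLLM actually uses:

import torch

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # x: [..., 2 * d]; returns silu(x[..., :d]) * x[..., d:]
    d = x.shape[-1] // 2
    return torch.nn.functional.silu(x[..., :d]) * x[..., d:]

def gelu_new_ref(x: torch.Tensor) -> torch.Tensor:
    return 0.5 * x * (1.0 + torch.tanh(0.79788456 * (x + 0.044715 * x.pow(3))))

def gelu_fast_ref(x: torch.Tensor) -> torch.Tensor:
    return 0.5 * x * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))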
diff --git a/csrc/attention/attention_dtypes.h b/csrc/attention/attention_dtypes.h
new file mode 100644
index 0000000000000000000000000000000000000000..61748e6b1eee6ffab04d3da2b0da7da6546fef43
--- /dev/null
+++ b/csrc/attention/attention_dtypes.h
@@ -0,0 +1,7 @@
+#pragma once
+
+#include "attention_generic.cuh"
+#include "dtype_float16.cuh"
+#include "dtype_float32.cuh"
+#include "dtype_bfloat16.cuh"
+#include "dtype_fp8_e5m2.cuh"
diff --git a/csrc/attention/attention_generic.cuh b/csrc/attention/attention_generic.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..31fb401cbe2c158bb7e3dfad266e794e5da58abc
--- /dev/null
+++ b/csrc/attention/attention_generic.cuh
@@ -0,0 +1,64 @@
+/*
+ * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <stdint.h>
+
+namespace vllm {
+
+// A vector type to store Q, K, V elements.
+template<typename T, int VEC_SIZE>
+struct Vec {};
+
+// A vector type to store FP32 accumulators.
+template<typename T>
+struct FloatVec {};
+
+// Template vector operations.
+template<typename Acc, typename A, typename B>
+inline __device__ Acc mul(A a, B b);
+
+template<typename T>
+inline __device__ float sum(T v);
+
+template<typename T>
+inline __device__ float dot(T a, T b) {
+ return sum(mul<T, T, T>(a, b));
+}
+
+template<typename A, typename T>
+inline __device__ float dot(T a, T b) {
+ return sum(mul<A, T, T>(a, b));
+}
+
+template<typename T>
+inline __device__ void zero(T& dst) {
+ constexpr int WORDS = sizeof(T) / 4;
+ union {
+ T raw;
+ uint32_t words[WORDS];
+ } tmp;
+
+#pragma unroll
+ for (int ii = 0; ii < WORDS; ++ii) {
+ tmp.words[ii] = 0u;
+ }
+ dst = tmp.raw;
+}
+
+} // namespace vllm
diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu
new file mode 100644
index 0000000000000000000000000000000000000000..a5ddeac74044004f9a728cbeb57dd339175ebf44
--- /dev/null
+++ b/csrc/attention/attention_kernels.cu
@@ -0,0 +1,951 @@
+/*
+ * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifdef USE_ROCM
+#include <hip/hip_runtime.h>
+#endif
+
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+
+#include "attention_dtypes.h"
+#include "attention_utils.cuh"
+#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
+
+#include <algorithm>
+
+#ifndef USE_ROCM
+#define WARP_SIZE 32
+#else
+#define WARP_SIZE warpSize
+#endif
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
+
+namespace vllm {
+
+// Utility function for attention softmax.
+template<int NUM_WARPS>
+inline __device__ float block_sum(float* red_smem, float sum) {
+ // Decompose the thread index into warp / lane.
+ int warp = threadIdx.x / WARP_SIZE;
+ int lane = threadIdx.x % WARP_SIZE;
+
+ // Compute the sum per warp.
+#pragma unroll
+ for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
+ sum += VLLM_SHFL_XOR_SYNC(sum, mask);
+ }
+
+ // Warp leaders store the data to shared memory.
+ if (lane == 0) {
+ red_smem[warp] = sum;
+ }
+
+ // Make sure the data is in shared memory.
+ __syncthreads();
+
+ // The warps compute the final sums.
+ if (lane < NUM_WARPS) {
+ sum = red_smem[lane];
+ }
+
+ // Parallel reduction inside the warp.
+#pragma unroll
+ for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
+ sum += VLLM_SHFL_XOR_SYNC(sum, mask);
+ }
+
+ // Broadcast to other threads.
+ return VLLM_SHFL_SYNC(sum, 0);
+}
+
+// TODO(woosuk): Merge the last two dimensions of the grid.
+// Grid: (num_heads, num_seqs, max_num_partitions).
+template<
+ typename scalar_t,
+ typename cache_t,
+ int HEAD_SIZE,
+ int BLOCK_SIZE,
+ int NUM_THREADS,
+ bool IS_FP8_E5M2_KV_CACHE,
+ int PARTITION_SIZE = 0> // Zero means no partitioning.
+__device__ void paged_attention_kernel(
+ float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
+ float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
+ scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size]
+ const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
+ const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
+ const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
+ const int num_kv_heads, // [num_heads]
+ const float scale,
+ const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
+ const int* __restrict__ context_lens, // [num_seqs]
+ const int max_num_blocks_per_seq,
+ const float* __restrict__ alibi_slopes, // [num_heads]
+ const int q_stride,
+ const int kv_block_stride,
+ const int kv_head_stride) {
+ const int seq_idx = blockIdx.y;
+ const int partition_idx = blockIdx.z;
+ const int max_num_partitions = gridDim.z;
+ constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
+ const int context_len = context_lens[seq_idx];
+ if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) {
+ // No work to do. Terminate the thread block.
+ return;
+ }
+
+ const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
+ const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;
+
+ // [start_block_idx, end_block_idx) is the range of blocks to process.
+ const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
+ const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks);
+ const int num_blocks = end_block_idx - start_block_idx;
+
+ // [start_token_idx, end_token_idx) is the range of tokens to process.
+ const int start_token_idx = start_block_idx * BLOCK_SIZE;
+ const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len);
+ const int num_tokens = end_token_idx - start_token_idx;
+
+ constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
+ constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS
+ assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
+ constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
+ constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+ const int thread_idx = threadIdx.x;
+ const int warp_idx = thread_idx / WARP_SIZE;
+ const int lane = thread_idx % WARP_SIZE;
+
+ const int head_idx = blockIdx.x;
+ const int num_heads = gridDim.x;
+ const int num_queries_per_kv = num_heads / num_kv_heads;
+ const int kv_head_idx = head_idx / num_queries_per_kv;
+ const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
+
+ // A vector type to store a part of a key or a query.
+ // The vector size is configured in such a way that the threads in a thread group
+ // fetch or compute 16 bytes at a time.
+ // For example, if the size of a thread group is 4 and the data type is half,
+ // then the vector size is 16 / (4 * sizeof(half)) == 2.
+ constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
+ using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
+ using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
+#ifdef ENABLE_FP8_E5M2
+ using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
+#endif
+
+ constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
+ constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
+
+ const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
+ const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
+
+ // Load the query to registers.
+ // Each thread in a thread group has a different part of the query.
+ // For example, if the the thread group size is 4, then the first thread in the group
+ // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
+ // th vectors of the query, and so on.
+ // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
+ const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
+ __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
+#pragma unroll
+ for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
+ const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
+ q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
+ }
+ __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs
+
+ // Memory planning.
+ extern __shared__ char shared_mem[];
+ // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy.
+ float* logits = reinterpret_cast<float*>(shared_mem);
+ // Workspace for reduction.
+ __shared__ float red_smem[2 * NUM_WARPS];
+
+ // x == THREAD_GROUP_SIZE * VEC_SIZE
+ // Each thread group fetches x elements from the key at a time.
+ constexpr int x = 16 / sizeof(cache_t);
+ float qk_max = -FLT_MAX;
+
+ // Iterate over the key blocks.
+ // Each warp fetches a block of keys for each iteration.
+ // Each thread group in a warp fetches a key from the block, and computes
+ // dot product with the query.
+ const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
+ for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
+ // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64
+ // because int32 can lead to overflow when this variable is multiplied by large numbers
+ // (e.g., kv_block_stride).
+ const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
+
+ // Load a key to registers.
+ // Each thread in a thread group has a different part of the key.
+ // For example, if the the thread group size is 4, then the first thread in the group
+ // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th
+ // vectors of the key, and so on.
+ for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
+ const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
+ const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
+ K_vec k_vecs[NUM_VECS_PER_THREAD];
+
+#pragma unroll
+ for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
+ const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride
+ + kv_head_idx * kv_head_stride
+ + physical_block_offset * x;
+ const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
+ const int offset1 = (vec_idx * VEC_SIZE) / x;
+ const int offset2 = (vec_idx * VEC_SIZE) % x;
+ if constexpr (IS_FP8_E5M2_KV_CACHE) {
+#ifdef ENABLE_FP8_E5M2
+ Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+ // Vector conversion from Quant_vec to K_vec.
+ k_vecs[j] = fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
+#else
+ assert(false);
+#endif
+ } else {
+ k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+ }
+ }
+
+ // Compute dot product.
+ // This includes a reduction across the threads in the same thread group.
+ float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
+ // Add the ALiBi bias if slopes are given.
+ qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
+
+ if (thread_group_offset == 0) {
+ // Store the partial reductions to shared memory.
+ // NOTE(woosuk): It is required to zero out the masked logits.
+ const bool mask = token_idx >= context_len;
+ logits[token_idx - start_token_idx] = mask ? 0.f : qk;
+ // Update the max value.
+ qk_max = mask ? qk_max : fmaxf(qk_max, qk);
+ }
+ }
+ }
+
+ // Perform reduction across the threads in the same warp to get the
+ // max qk value for each "warp" (not across the thread block yet).
+ // The 0-th thread of each thread group already has its max qk value.
+#pragma unroll
+ for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
+ qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
+ }
+ if (lane == 0) {
+ red_smem[warp_idx] = qk_max;
+ }
+ __syncthreads();
+
+ // TODO(woosuk): Refactor this part.
+ // Get the max qk value for the sequence.
+ qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
+#pragma unroll
+ for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
+ qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
+ }
+ // Broadcast the max qk value to all threads.
+ qk_max = VLLM_SHFL_SYNC(qk_max, 0);
+
+ // Get the sum of the exp values.
+ float exp_sum = 0.f;
+ for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
+ float val = __expf(logits[i] - qk_max);
+ logits[i] = val;
+ exp_sum += val;
+ }
+ exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
+
+ // Compute softmax.
+ const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
+ for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
+ logits[i] *= inv_sum;
+ }
+ __syncthreads();
+
+ // If partitioning is enabled, store the max logit and exp_sum.
+ if (USE_PARTITIONING && thread_idx == 0) {
+ float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
+ + head_idx * max_num_partitions
+ + partition_idx;
+ *max_logits_ptr = qk_max;
+ float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
+ + head_idx * max_num_partitions
+ + partition_idx;
+ *exp_sums_ptr = exp_sum;
+ }
+
+ // Each thread will fetch 16 bytes from the value cache at a time.
+ constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
+ using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
+ using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
+#ifdef ENABLE_FP8_E5M2
+ using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
+#endif
+ using Float_L_vec = typename FloatVec<L_vec>::Type;
+
+ constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
+ constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
+ constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
+
+ // NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
+ float accs[NUM_ROWS_PER_THREAD];
+#pragma unroll
+ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+ accs[i] = 0.f;
+ }
+
+ scalar_t zero_value;
+ zero(zero_value);
+ for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
+ // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64
+ // because int32 can lead to overflow when this variable is multiplied by large numbers
+ // (e.g., kv_block_stride).
+ const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
+ const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
+ const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
+ L_vec logits_vec;
+ from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));
+
+ const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
+ + kv_head_idx * kv_head_stride;
+#pragma unroll
+ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+ const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+ if (row_idx < HEAD_SIZE) {
+ const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
+ V_vec v_vec;
+ if constexpr (IS_FP8_E5M2_KV_CACHE) {
+#ifdef ENABLE_FP8_E5M2
+ V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
+ // Vector conversion from V_quant_vec to V_vec.
+ v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(v_quant_vec);
+#else
+ assert(false);
+#endif
+ } else {
+ v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
+ }
+ if (block_idx == num_context_blocks - 1) {
+ // NOTE(woosuk): When v_vec contains the tokens that are out of the context,
+ // we should explicitly zero out the values since they may contain NaNs.
+ // See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
+ scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
+#pragma unroll
+ for (int j = 0; j < V_VEC_SIZE; j++) {
+ v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
+ }
+ }
+ accs[i] += dot(logits_vec, v_vec);
+ }
+ }
+ }
+
+ // Perform reduction within each warp.
+#pragma unroll
+ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+ float acc = accs[i];
+#pragma unroll
+ for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
+ acc += VLLM_SHFL_XOR_SYNC(acc, mask);
+ }
+ accs[i] = acc;
+ }
+
+ // NOTE(woosuk): A barrier is required because the shared memory space for logits
+ // is reused for the output.
+ __syncthreads();
+
+ // Perform reduction across warps.
+ float* out_smem = reinterpret_cast<float*>(shared_mem);
+#pragma unroll
+ for (int i = NUM_WARPS; i > 1; i /= 2) {
+ int mid = i / 2;
+ // Upper warps write to shared memory.
+ if (warp_idx >= mid && warp_idx < i) {
+ float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
+#pragma unroll
+ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+ const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+ if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
+ dst[row_idx] = accs[i];
+ }
+ }
+ }
+ __syncthreads();
+
+ // Lower warps update the output.
+ if (warp_idx < mid) {
+ const float* src = &out_smem[warp_idx * HEAD_SIZE];
+#pragma unroll
+ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+ const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+ if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
+ accs[i] += src[row_idx];
+ }
+ }
+ }
+ __syncthreads();
+ }
+
+ // Write the final output.
+ if (warp_idx == 0) {
+ scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
+ + head_idx * max_num_partitions * HEAD_SIZE
+ + partition_idx * HEAD_SIZE;
+#pragma unroll
+ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+ const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+ if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
+ from_float(*(out_ptr + row_idx), accs[i]);
+ }
+ }
+ }
+}
+
+// Grid: (num_heads, num_seqs, 1).
+template<
+ typename scalar_t,
+ typename cache_t,
+ int HEAD_SIZE,
+ int BLOCK_SIZE,
+ int NUM_THREADS,
+ bool IS_FP8_E5M2_KV_CACHE>
+__global__ void paged_attention_v1_kernel(
+ scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
+ const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
+ const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
+ const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
+ const int num_kv_heads, // [num_heads]
+ const float scale,
+ const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
+ const int* __restrict__ context_lens, // [num_seqs]
+ const int max_num_blocks_per_seq,
+ const float* __restrict__ alibi_slopes, // [num_heads]
+ const int q_stride,
+ const int kv_block_stride,
+ const int kv_head_stride) {
+ paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE>(
+ /* exp_sums */ nullptr, /* max_logits */ nullptr,
+ out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
+ max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
+}
+
+// Grid: (num_heads, num_seqs, max_num_partitions).
+template<
+ typename scalar_t,
+ typename cache_t,
+ int HEAD_SIZE,
+ int BLOCK_SIZE,
+ int NUM_THREADS,
+ bool IS_FP8_E5M2_KV_CACHE,
+ int PARTITION_SIZE>
+__global__ void paged_attention_v2_kernel(
+ float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
+ float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
+ scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
+ const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
+ const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
+ const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
+ const int num_kv_heads, // [num_heads]
+ const float scale,
+ const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
+ const int* __restrict__ context_lens, // [num_seqs]
+ const int max_num_blocks_per_seq,
+ const float* __restrict__ alibi_slopes, // [num_heads]
+ const int q_stride,
+ const int kv_block_stride,
+ const int kv_head_stride) {
+ paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE>(
+ exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
+ block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
+ q_stride, kv_block_stride, kv_head_stride);
+}
+
+// Grid: (num_heads, num_seqs).
+template<
+ typename scalar_t,
+ int HEAD_SIZE,
+ int NUM_THREADS,
+ int PARTITION_SIZE>
+__global__ void paged_attention_v2_reduce_kernel(
+ scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
+ const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
+ const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
+ const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
+ const int* __restrict__ context_lens, // [num_seqs]
+ const int max_num_partitions) {
+ const int num_heads = gridDim.x;
+ const int head_idx = blockIdx.x;
+ const int seq_idx = blockIdx.y;
+ const int context_len = context_lens[seq_idx];
+ const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
+ if (num_partitions == 1) {
+ // No need to reduce. Only copy tmp_out to out.
+ scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+ const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
+ + head_idx * max_num_partitions * HEAD_SIZE;
+ for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
+ out_ptr[i] = tmp_out_ptr[i];
+ }
+ // Terminate the thread block.
+ return;
+ }
+
+ constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+ const int warp_idx = threadIdx.x / WARP_SIZE;
+ const int lane = threadIdx.x % WARP_SIZE;
+
+ // Size: 2 * num_partitions.
+ extern __shared__ char shared_mem[];
+ // Workspace for reduction.
+ __shared__ float red_smem[2 * NUM_WARPS];
+
+ // Load max logits to shared memory.
+ float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
+ const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
+ + head_idx * max_num_partitions;
+ float max_logit = -FLT_MAX;
+ for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
+ const float l = max_logits_ptr[i];
+ shared_max_logits[i] = l;
+ max_logit = fmaxf(max_logit, l);
+ }
+ __syncthreads();
+
+ // Get the global max logit.
+ // Reduce within the warp.
+#pragma unroll
+ for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
+ max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
+ }
+ if (lane == 0) {
+ red_smem[warp_idx] = max_logit;
+ }
+ __syncthreads();
+ // Reduce across warps.
+ max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
+#pragma unroll
+ for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
+ max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
+ }
+ // Broadcast the max value to all threads.
+ max_logit = VLLM_SHFL_SYNC(max_logit, 0);
+
+ // Load rescaled exp sums to shared memory.
+ float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
+ const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
+ + head_idx * max_num_partitions;
+ float global_exp_sum = 0.0f;
+ for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
+ float l = shared_max_logits[i];
+ float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
+ global_exp_sum += rescaled_exp_sum;
+ shared_exp_sums[i] = rescaled_exp_sum;
+ }
+ __syncthreads();
+ global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
+ const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
+
+ // Aggregate tmp_out to out.
+ const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
+ + head_idx * max_num_partitions * HEAD_SIZE;
+ scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+#pragma unroll
+ for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
+ float acc = 0.0f;
+ for (int j = 0; j < num_partitions; ++j) {
+ acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum;
+ }
+ from_float(out_ptr[i], acc);
+ }
+}
+
+} // namespace vllm
+
+#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
+ VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
+ ((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
+ IS_FP8_E5M2_KV_CACHE>), shared_mem_size); \
+ vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
+ IS_FP8_E5M2_KV_CACHE><<<grid, block, shared_mem_size, stream>>>( \
+ out_ptr, \
+ query_ptr, \
+ key_cache_ptr, \
+ value_cache_ptr, \
+ num_kv_heads, \
+ scale, \
+ block_tables_ptr, \
+ context_lens_ptr, \
+ max_num_blocks_per_seq, \
+ alibi_slopes_ptr, \
+ q_stride, \
+ kv_block_stride, \
+ kv_head_stride);
+
+// TODO(woosuk): Tune NUM_THREADS.
+template<
+ typename T,
+ typename CACHE_T,
+ int BLOCK_SIZE,
+ bool IS_FP8_E5M2_KV_CACHE,
+ int NUM_THREADS = 128>
+void paged_attention_v1_launcher(
+ torch::Tensor& out,
+ torch::Tensor& query,
+ torch::Tensor& key_cache,
+ torch::Tensor& value_cache,
+ int num_kv_heads,
+ float scale,
+ torch::Tensor& block_tables,
+ torch::Tensor& context_lens,
+ int max_context_len,
+ const c10::optional<torch::Tensor>& alibi_slopes) {
+ int num_seqs = query.size(0);
+ int num_heads = query.size(1);
+ int head_size = query.size(2);
+ int max_num_blocks_per_seq = block_tables.size(1);
+ int q_stride = query.stride(0);
+ int kv_block_stride = key_cache.stride(0);
+ int kv_head_stride = key_cache.stride(1);
+
+ int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
+ assert(head_size % thread_group_size == 0);
+
+ // NOTE: alibi_slopes is optional.
+ const float* alibi_slopes_ptr = alibi_slopes ?
+ reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
+ : nullptr;
+
+ T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
+ T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
+ CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
+ CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
+ int* block_tables_ptr = block_tables.data_ptr<int>();
+ int* context_lens_ptr = context_lens.data_ptr<int>();
+
+ constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+ int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE;
+ int logits_size = padded_max_context_len * sizeof(float);
+ int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
+ // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
+ // Keep that in sync with the logic here!
+ int shared_mem_size = std::max(logits_size, outputs_size);
+
+ dim3 grid(num_heads, num_seqs, 1);
+ dim3 block(NUM_THREADS);
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+ switch (head_size) {
+ // NOTE(woosuk): To reduce the compilation time, we only compile for the
+ // head sizes that we use in the model. However, we can easily extend this
+ // to support any head size which is a multiple of 16.
+ case 64:
+ LAUNCH_PAGED_ATTENTION_V1(64);
+ break;
+ case 80:
+ LAUNCH_PAGED_ATTENTION_V1(80);
+ break;
+ case 96:
+ LAUNCH_PAGED_ATTENTION_V1(96);
+ break;
+ case 112:
+ LAUNCH_PAGED_ATTENTION_V1(112);
+ break;
+ case 128:
+ LAUNCH_PAGED_ATTENTION_V1(128);
+ break;
+ case 256:
+ LAUNCH_PAGED_ATTENTION_V1(256);
+ break;
+ default:
+ TORCH_CHECK(false, "Unsupported head size: ", head_size);
+ break;
+ }
+}
+
+#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) \
+ paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>( \
+ out, \
+ query, \
+ key_cache, \
+ value_cache, \
+ num_kv_heads, \
+ scale, \
+ block_tables, \
+ context_lens, \
+ max_context_len, \
+ alibi_slopes);
+
+// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
+// 1, 2, 4, 64, 128, 256.
+#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
+ switch (block_size) { \
+ case 8: \
+ CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE); \
+ break; \
+ case 16: \
+ CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE); \
+ break; \
+ case 32: \
+ CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE); \
+ break; \
+ default: \
+ TORCH_CHECK(false, "Unsupported block size: ", block_size); \
+ break; \
+ }
+
+void paged_attention_v1(
+ torch::Tensor& out, // [num_seqs, num_heads, head_size]
+ torch::Tensor& query, // [num_seqs, num_heads, head_size]
+ torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
+ torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
+ int num_kv_heads, // [num_heads]
+ float scale,
+ torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
+ torch::Tensor& context_lens, // [num_seqs]
+ int block_size,
+ int max_context_len,
+ const c10::optional<torch::Tensor>& alibi_slopes,
+ const std::string& kv_cache_dtype) {
+ if (kv_cache_dtype == "auto") {
+ if (query.dtype() == at::ScalarType::Float) {
+ CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false);
+ } else if (query.dtype() == at::ScalarType::Half) {
+ CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
+ } else if (query.dtype() == at::ScalarType::BFloat16) {
+ CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
+ } else {
+ TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+ }
+ } else if (kv_cache_dtype == "fp8_e5m2") {
+ if (query.dtype() == at::ScalarType::Float) {
+ CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
+ } else if (query.dtype() == at::ScalarType::Half) {
+ CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
+ } else if (query.dtype() == at::ScalarType::BFloat16) {
+ CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
+ } else {
+ TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+ }
+ } else {
+ TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
+ }
+}
+
+#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
+ vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
+ IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE> \
+ <<<grid, block, shared_mem_size, stream>>>( \
+ exp_sums_ptr, \
+ max_logits_ptr, \
+ tmp_out_ptr, \
+ query_ptr, \
+ key_cache_ptr, \
+ value_cache_ptr, \
+ num_kv_heads, \
+ scale, \
+ block_tables_ptr, \
+ context_lens_ptr, \
+ max_num_blocks_per_seq, \
+ alibi_slopes_ptr, \
+ q_stride, \
+ kv_block_stride, \
+ kv_head_stride); \
+ vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE> \
+ <<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
+ out_ptr, \
+ exp_sums_ptr, \
+ max_logits_ptr, \
+ tmp_out_ptr, \
+ context_lens_ptr, \
+ max_num_partitions);
+
+template<
+ typename T,
+ typename CACHE_T,
+ int BLOCK_SIZE,
+ bool IS_FP8_E5M2_KV_CACHE,
+ int NUM_THREADS = 128,
+ int PARTITION_SIZE = 512>
+void paged_attention_v2_launcher(
+ torch::Tensor& out,
+ torch::Tensor& exp_sums,
+ torch::Tensor& max_logits,
+ torch::Tensor& tmp_out,
+ torch::Tensor& query,
+ torch::Tensor& key_cache,
+ torch::Tensor& value_cache,
+ int num_kv_heads,
+ float scale,
+ torch::Tensor& block_tables,
+ torch::Tensor& context_lens,
+ int max_context_len,
+ const c10::optional<torch::Tensor>& alibi_slopes) {
+ int num_seqs = query.size(0);
+ int num_heads = query.size(1);
+ int head_size = query.size(2);
+ int max_num_blocks_per_seq = block_tables.size(1);
+ int q_stride = query.stride(0);
+ int kv_block_stride = key_cache.stride(0);
+ int kv_head_stride = key_cache.stride(1);
+
+ int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
+ assert(head_size % thread_group_size == 0);
+
+ // NOTE: alibi_slopes is optional.
+ const float* alibi_slopes_ptr = alibi_slopes ?
+ reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
+ : nullptr;
+
+ T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
+ float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
+ float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
+ T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
+ T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
+ CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
+ CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
+ int* block_tables_ptr = block_tables.data_ptr<int>();
+ int* context_lens_ptr = context_lens.data_ptr<int>();
+
+ constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+ int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
+ int logits_size = PARTITION_SIZE * sizeof(float);
+ int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
+
+ // For paged attention v2 kernel.
+ dim3 grid(num_heads, num_seqs, max_num_partitions);
+ int shared_mem_size = std::max(logits_size, outputs_size);
+ // For paged attention v2 reduce kernel.
+ dim3 reduce_grid(num_heads, num_seqs);
+ int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
+
+ dim3 block(NUM_THREADS);
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+ switch (head_size) {
+ // NOTE(woosuk): To reduce the compilation time, we only compile for the
+ // head sizes that we use in the model. However, we can easily extend this
+ // to support any head size which is a multiple of 16.
+ case 64:
+ LAUNCH_PAGED_ATTENTION_V2(64);
+ break;
+ case 80:
+ LAUNCH_PAGED_ATTENTION_V2(80);
+ break;
+ case 96:
+ LAUNCH_PAGED_ATTENTION_V2(96);
+ break;
+ case 112:
+ LAUNCH_PAGED_ATTENTION_V2(112);
+ break;
+ case 128:
+ LAUNCH_PAGED_ATTENTION_V2(128);
+ break;
+ case 256:
+ LAUNCH_PAGED_ATTENTION_V2(256);
+ break;
+ default:
+ TORCH_CHECK(false, "Unsupported head size: ", head_size);
+ break;
+ }
+}
+
+#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) \
+ paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>( \
+ out, \
+ exp_sums, \
+ max_logits, \
+ tmp_out, \
+ query, \
+ key_cache, \
+ value_cache, \
+ num_kv_heads, \
+ scale, \
+ block_tables, \
+ context_lens, \
+ max_context_len, \
+ alibi_slopes);
+
+// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
+// 1, 2, 4, 64, 128, 256.
+#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
+ switch (block_size) { \
+ case 8: \
+ CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE); \
+ break; \
+ case 16: \
+ CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE); \
+ break; \
+ case 32: \
+ CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE); \
+ break; \
+ default: \
+ TORCH_CHECK(false, "Unsupported block size: ", block_size); \
+ break; \
+ }
+
+void paged_attention_v2(
+ torch::Tensor& out, // [num_seqs, num_heads, head_size]
+ torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
+ torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
+ torch::Tensor& tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
+ torch::Tensor& query, // [num_seqs, num_heads, head_size]
+ torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
+ torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
+ int num_kv_heads, // [num_heads]
+ float scale,
+ torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
+ torch::Tensor& context_lens, // [num_seqs]
+ int block_size,
+ int max_context_len,
+ const c10::optional<torch::Tensor>& alibi_slopes,
+ const std::string& kv_cache_dtype) {
+ if (kv_cache_dtype == "auto") {
+ if (query.dtype() == at::ScalarType::Float) {
+ CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false);
+ } else if (query.dtype() == at::ScalarType::Half) {
+ CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
+ } else if (query.dtype() == at::ScalarType::BFloat16) {
+ CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
+ } else {
+ TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+ }
+ } else if (kv_cache_dtype == "fp8_e5m2") {
+ if (query.dtype() == at::ScalarType::Float) {
+ CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
+ } else if (query.dtype() == at::ScalarType::Half) {
+ CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
+ } else if (query.dtype() == at::ScalarType::BFloat16) {
+ CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
+ } else {
+ TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+ }
+ } else {
+ TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
+ }
+}
+
+#undef WARP_SIZE
+#undef MAX
+#undef MIN
+#undef DIVIDE_ROUND_UP
diff --git a/csrc/attention/attention_utils.cuh b/csrc/attention/attention_utils.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..ff64c4bd8f80c200647e688db1a74c711a9f709d
--- /dev/null
+++ b/csrc/attention/attention_utils.cuh
@@ -0,0 +1,56 @@
+/*
+ * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include "../cuda_compat.h"
+#include "attention_dtypes.h"
+
+#include <float.h>
+#include <type_traits>
+
+namespace vllm {
+
+// Q*K^T operation.
+template<int THREAD_GROUP_SIZE, typename Vec, int N>
+inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
+ using A_vec = typename FloatVec::Type;
+ // Compute the parallel products for Q*K^T (treat vector lanes separately).
+ A_vec qk_vec = mul(q[0], k[0]);
+#pragma unroll
+ for (int ii = 1; ii < N; ++ii) {
+ qk_vec = fma(q[ii], k[ii], qk_vec);
+ }
+
+ // Finalize the reduction across lanes.
+ float qk = sum(qk_vec);
+#pragma unroll
+ for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
+ qk += VLLM_SHFL_XOR_SYNC(qk, mask);
+ }
+ return qk;
+}
+
+template<typename T, int THREAD_GROUP_SIZE>
+struct Qk_dot {
+ template<typename Vec, int N>
+ static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) {
+ return qk_dot_<THREAD_GROUP_SIZE>(q, k);
+ }
+};
+
+} // namespace vllm
diff --git a/csrc/attention/dtype_bfloat16.cuh b/csrc/attention/dtype_bfloat16.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..31e0cee01d2e1ff47b3381590c5cbc2c8446b556
--- /dev/null
+++ b/csrc/attention/dtype_bfloat16.cuh
@@ -0,0 +1,451 @@
+/*
+ * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include "attention_generic.cuh"
+#include "dtype_float32.cuh"
+
+#ifndef USE_ROCM
+ #include <cuda_bf16.h>
+ #include <cuda_fp16.h>
+#else
+ #include <hip/hip_bf16.h>
+ #include <hip/hip_fp16.h>
+
+ typedef __hip_bfloat162 __nv_bfloat162;
+ typedef __hip_bfloat16 __nv_bfloat16;
+#endif
+
+#include <stdint.h>
+
+namespace vllm {
+
+// Define custom BF16 vector data types.
+struct bf16_4_t {
+ __nv_bfloat162 x;
+ __nv_bfloat162 y;
+};
+
+struct bf16_8_t {
+ __nv_bfloat162 x;
+ __nv_bfloat162 y;
+ __nv_bfloat162 z;
+ __nv_bfloat162 w;
+};
+
+// BF16 vector types for Q, K, V.
+template<>
+struct Vec<__nv_bfloat16, 1> {
+ using Type = __nv_bfloat16;
+};
+template<>
+struct Vec<__nv_bfloat16, 2> {
+ using Type = __nv_bfloat162;
+};
+template<>
+struct Vec<__nv_bfloat16, 4> {
+ using Type = bf16_4_t;
+};
+template<>
+struct Vec<__nv_bfloat16, 8> {
+ using Type = bf16_8_t;
+};
+
+// FP32 accumulator vector types corresponding to Vec.
+template<>
+struct FloatVec<__nv_bfloat16> {
+ using Type = float;
+};
+template<>
+struct FloatVec<__nv_bfloat162> {
+ using Type = float2;
+};
+template<>
+struct FloatVec<bf16_4_t> {
+ using Type = Float4_;
+};
+template<>
+struct FloatVec<bf16_8_t> {
+ using Type = Float8_;
+};
+
+// Utility functions for type conversions.
+inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+ assert(false);
+#else
+ return __bfloat1622float2(val);
+#endif
+}
+
+inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+ assert(false);
+#else
+ return __bfloat162bfloat162(val);
+#endif
+}
+
+// Vector addition.
+inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+ assert(false);
+#else
+ #ifndef USE_ROCM
+ return a + b;
+ #else
+ return __hadd(a, b);
+ #endif
+#endif
+}
+
+inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+ assert(false);
+#else
+ return __hadd2(a, b);
+#endif
+}
+
+inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) {
+ bf16_4_t c;
+ c.x = add(a.x, b.x);
+ c.y = add(a.y, b.y);
+ return c;
+}
+
+inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) {
+ bf16_8_t c;
+ c.x = add(a.x, b.x);
+ c.y = add(a.y, b.y);
+ c.z = add(a.z, b.z);
+ c.w = add(a.w, b.w);
+ return c;
+}
+
+inline __device__ float2 add(__nv_bfloat162 a, float2 fb) {
+ float2 fa = bf1622float2(a);
+ return add(fa, fb);
+}
+
+inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) {
+ Float4_ fc;
+ fc.x = add(a.x, fb.x);
+ fc.y = add(a.y, fb.y);
+ return fc;
+}
+
+inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
+ Float8_ fc;
+ fc.x = add(a.x, fb.x);
+ fc.y = add(a.y, fb.y);
+ fc.z = add(a.z, fb.z);
+ fc.w = add(a.w, fb.w);
+ return fc;
+}
+
+// Vector multiplication.
+template<>
+inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+ assert(false);
+#else
+ return __hmul(a, b);
+#endif
+}
+
+template<>
+inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+ assert(false);
+#else
+ return __hmul2(a, b);
+#endif
+}
+
+template<>
+inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
+ return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
+}
+
+template<>
+inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
+ bf16_4_t c;
+ c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
+ c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
+ return c;
+}
+
+template<>
+inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) {
+ __nv_bfloat162 s = bf162bf162(a);
+ bf16_4_t c;
+ c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
+ c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
+ return c;
+}
+
+template<>
+inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
+ bf16_8_t c;
+ c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
+ c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
+ c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
+ c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
+ return c;
+}
+
+template<>
+inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) {
+ __nv_bfloat162 s = bf162bf162(a);
+ bf16_8_t c;
+ c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
+ c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
+ c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z);
+ c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w);
+ return c;
+}
+
+template<>
+inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) {
+ float fa = __bfloat162float(a);
+ float fb = __bfloat162float(b);
+ return fa * fb;
+}
+
+template<>
+inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
+ float2 fa = bf1622float2(a);
+ float2 fb = bf1622float2(b);
+ return mul<float2, float2, float2>(fa, fb);
+}
+
+template<>
+inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
+ return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
+}
+
+template<>
+inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
+ Float4_ fc;
+ fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
+ fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
+ return fc;
+}
+
+template<>
+inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) {
+ __nv_bfloat162 s = bf162bf162(a);
+ Float4_ fc;
+ fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
+ fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
+ return fc;
+}
+
+template<>
+inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
+ Float8_ fc;
+ fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
+ fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
+ fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
+ fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
+ return fc;
+}
+
+template<>
+inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
+ __nv_bfloat162 s = bf162bf162(a);
+ Float8_ fc;
+ fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
+ fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
+ fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.z);
+ fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.w);
+ return fc;
+}
+
+// Vector fused multiply-add.
+inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+ assert(false);
+#else
+ return __hfma2(a, b, c);
+#endif
+}
+
+inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+ assert(false);
+#else
+ return __hfma2(bf162bf162(a), b, c);
+#endif
+}
+
+inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) {
+ bf16_4_t d;
+ d.x = fma(a.x, b.x, c.x);
+ d.y = fma(a.y, b.y, c.y);
+ return d;
+}
+
+inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c) {
+ __nv_bfloat162 s = bf162bf162(a);
+ bf16_4_t d;
+ d.x = fma(s, b.x, c.x);
+ d.y = fma(s, b.y, c.y);
+ return d;
+}
+
+inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) {
+ bf16_8_t d;
+ d.x = fma(a.x, b.x, c.x);
+ d.y = fma(a.y, b.y, c.y);
+ d.z = fma(a.z, b.z, c.z);
+ d.w = fma(a.w, b.w, c.w);
+ return d;
+}
+
+inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c) {
+ __nv_bfloat162 s = bf162bf162(a);
+ bf16_8_t d;
+ d.x = fma(s, b.x, c.x);
+ d.y = fma(s, b.y, c.y);
+ d.z = fma(s, b.z, c.z);
+ d.w = fma(s, b.w, c.w);
+ return d;
+}
+
+inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc) {
+ return __bfloat162float(a) * __bfloat162float(b) + fc;
+}
+
+inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc) {
+ float2 fa = bf1622float2(a);
+ float2 fb = bf1622float2(b);
+ return fma(fa, fb, fc);
+}
+
+inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc) {
+ return fma(bf162bf162(a), b, fc);
+}
+
+inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) {
+ Float4_ fd;
+ fd.x = fma(a.x, b.x, fc.x);
+ fd.y = fma(a.y, b.y, fc.y);
+ return fd;
+}
+
+inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc) {
+ __nv_bfloat162 s = bf162bf162(a);
+ Float4_ fd;
+ fd.x = fma(s, b.x, fc.x);
+ fd.y = fma(s, b.y, fc.y);
+ return fd;
+}
+
+inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) {
+ Float8_ fd;
+ fd.x = fma(a.x, b.x, fc.x);
+ fd.y = fma(a.y, b.y, fc.y);
+ fd.z = fma(a.z, b.z, fc.z);
+ fd.w = fma(a.w, b.w, fc.w);
+ return fd;
+}
+
+inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) {
+ __nv_bfloat162 s = bf162bf162(a);
+ Float8_ fd;
+ fd.x = fma(s, b.x, fc.x);
+ fd.y = fma(s, b.y, fc.y);
+ fd.z = fma(s, b.z, fc.z);
+ fd.w = fma(s, b.w, fc.w);
+ return fd;
+}
+
+// Vector sum.
+template<>
+inline __device__ float sum(__nv_bfloat16 v) {
+ return __bfloat162float(v);
+}
+
+template<>
+inline __device__ float sum(__nv_bfloat162 v) {
+ float2 vf = bf1622float2(v);
+ return vf.x + vf.y;
+}
+
+template<>
+inline __device__ float sum(bf16_4_t v) {
+ return sum(v.x) + sum(v.y);
+}
+
+template<>
+inline __device__ float sum(bf16_8_t v) {
+ return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
+}
+
+// From float32 to bfloat16.
+inline __device__ void from_float(__nv_bfloat16& dst, float src) {
+ dst = __float2bfloat16(src);
+}
+
+inline __device__ void from_float(__nv_bfloat162& dst, float2 src) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+ assert(false);
+#else
+ dst = __float22bfloat162_rn(src);
+#endif
+}
+
+inline __device__ void from_float(bf16_4_t& dst, Float4_ src) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+ assert(false);
+#else
+ dst.x = __float22bfloat162_rn(src.x);
+ dst.y = __float22bfloat162_rn(src.y);
+#endif
+}
+
+inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+ assert(false);
+#else
+ dst.x = __float22bfloat162_rn(src.x);
+ dst.y = __float22bfloat162_rn(src.y);
+ dst.z = __float22bfloat162_rn(src.z);
+ dst.w = __float22bfloat162_rn(src.w);
+#endif
+}
+
+// From bfloat16 to float32.
+inline __device__ float to_float(__nv_bfloat16 u) {
+ return __bfloat162float(u);
+}
+
+// Zero-out a variable.
+inline __device__ void zero(__nv_bfloat16& dst) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+ assert(false);
+#else
+ // Same as CUDART_ZERO_BF16 introduced in CUDA 12.2.
+ dst = __ushort_as_bfloat16((unsigned short)0x0000U);
+#endif
+}
+
+} // namespace vllm
diff --git a/csrc/attention/dtype_float16.cuh b/csrc/attention/dtype_float16.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..d3271e69cd69d93abe03539604b17a380eb094b8
--- /dev/null
+++ b/csrc/attention/dtype_float16.cuh
@@ -0,0 +1,502 @@
+/*
+ * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include "attention_generic.cuh"
+#include "dtype_float32.cuh"
+
+#ifdef USE_ROCM
+ #include <hip/hip_fp16.h>
+#endif
+
+#include <stdint.h>
+
+namespace vllm {
+
+// FP16 vector types for Q, K, V.
+template<>
+struct Vec<uint16_t, 1> {
+ using Type = uint16_t;
+};
+template<>
+struct Vec<uint16_t, 2> {
+ using Type = uint32_t;
+};
+template<>
+struct Vec<uint16_t, 4> {
+ using Type = uint2;
+};
+template<>
+struct Vec<uint16_t, 8> {
+ using Type = uint4;
+};
+
+// FP32 accumulator vector types corresponding to Vec.
+template<>
+struct FloatVec<uint16_t> {
+ using Type = float;
+};
+template<>
+struct FloatVec<uint32_t> {
+ using Type = float2;
+};
+template<>
+struct FloatVec<uint2> {
+ using Type = Float4_;
+};
+template<>
+struct FloatVec<uint4> {
+ using Type = Float8_;
+};
+
+// Utility functions for type conversions.
+inline __device__ uint32_t h0_h0(uint16_t a) {
+#ifndef USE_ROCM
+ uint32_t b;
+ asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
+ return b;
+#else
+ union {
+ uint32_t u32;
+ uint16_t u16[2];
+ } tmp;
+ tmp.u16[0] = a;
+ tmp.u16[1] = a;
+ return tmp.u32;
+#endif
+}
+
+inline __device__ float half_to_float(uint16_t h) {
+ float f;
+#ifndef USE_ROCM
+ asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
+#else
+ asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(f) : "v"(h));
+#endif
+ return f;
+}
+
+inline __device__ float2 half2_to_float2(uint32_t v) {
+#ifndef USE_ROCM
+ uint16_t lo, hi;
+ asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
+ return make_float2(half_to_float(lo), half_to_float(hi));
+#else
+ union {
+ uint32_t u32;
+ uint16_t u16[2];
+ } tmp;
+ tmp.u32 = v;
+ float2 ret;
+ ret.x = half_to_float(tmp.u16[0]);
+ ret.y = half_to_float(tmp.u16[1]);
+ return ret;
+#endif
+}
+
+inline __device__ uint16_t float_to_half(float f) {
+ union {
+ uint32_t u32;
+ uint16_t u16[2];
+ } tmp;
+#ifndef USE_ROCM
+ asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
+#else
+ asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(tmp.u32) : "v"(f));
+#endif
+ return tmp.u16[0];
+}
+
+inline __device__ uint32_t float2_to_half2(float2 f) {
+ union {
+ uint32_t u32;
+ uint16_t u16[2];
+ } tmp;
+#ifndef USE_ROCM
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+ asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
+ #else
+ asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
+ asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
+ #endif
+#else
+ tmp.u16[0] = float_to_half(f.x);
+ tmp.u16[1] = float_to_half(f.y);
+#endif
+ return tmp.u32;
+}
+
+// Vector addition.
+inline __device__ uint16_t add(uint16_t a, uint16_t b) {
+ uint16_t c;
+#ifndef USE_ROCM
+ asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
+#else
+ asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
+#endif
+ return c;
+}
+
+inline __device__ uint32_t add(uint32_t a, uint32_t b) {
+ uint32_t c;
+#ifndef USE_ROCM
+ asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
+#else
+ asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
+#endif
+ return c;
+}
+
+inline __device__ uint2 add(uint2 a, uint2 b) {
+ uint2 c;
+ c.x = add(a.x, b.x);
+ c.y = add(a.y, b.y);
+ return c;
+}
+
+inline __device__ uint4 add(uint4 a, uint4 b) {
+ uint4 c;
+ c.x = add(a.x, b.x);
+ c.y = add(a.y, b.y);
+ c.z = add(a.z, b.z);
+ c.w = add(a.w, b.w);
+ return c;
+}
+
+inline __device__ float2 add(uint32_t a, float2 fb) {
+ float2 fa = half2_to_float2(a);
+ return add(fa, fb);
+}
+
+inline __device__ Float4_ add(uint2 a, Float4_ fb) {
+ Float4_ fc;
+ fc.x = add(a.x, fb.x);
+ fc.y = add(a.y, fb.y);
+ return fc;
+}
+
+inline __device__ Float8_ add(uint4 a, Float8_ fb) {
+ Float8_ fc;
+ fc.x = add(a.x, fb.x);
+ fc.y = add(a.y, fb.y);
+ fc.z = add(a.z, fb.z);
+ fc.w = add(a.w, fb.w);
+ return fc;
+}
+
+// Vector multiplication.
+template<>
+inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
+ uint16_t c;
+#ifndef USE_ROCM
+ asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
+#else
+ asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
+#endif
+ return c;
+}
+
+template<>
+inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
+ uint32_t c;
+#ifndef USE_ROCM
+ asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
+#else
+ asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
+#endif
+ return c;
+}
+
+template<>
+inline __device__ uint32_t mul(uint16_t a, uint32_t b) {
+ return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b);
+}
+
+template<>
+inline __device__ uint2 mul(uint2 a, uint2 b) {
+ uint2 c;
+ c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
+ c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
+ return c;
+}
+
+template<>
+inline __device__ uint2 mul(uint16_t a, uint2 b) {
+ uint32_t s = h0_h0(a);
+ uint2 c;
+ c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
+ c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
+ return c;
+}
+
+template<>
+inline __device__ uint4 mul(uint4 a, uint4 b) {
+ uint4 c;
+ c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
+ c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
+ c.z = mul<uint32_t, uint32_t, uint32_t>(a.z, b.z);
+ c.w = mul<uint32_t, uint32_t, uint32_t>(a.w, b.w);
+ return c;
+}
+
+template<>
+inline __device__ uint4 mul(uint16_t a, uint4 b) {
+ uint32_t s = h0_h0(a);
+ uint4 c;
+ c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
+ c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
+ c.z = mul<uint32_t, uint32_t, uint32_t>(s, b.z);
+ c.w = mul<uint32_t, uint32_t, uint32_t>(s, b.w);
+ return c;
+}
+
+template<>
+inline __device__ float mul(uint16_t a, uint16_t b) {
+ float fa = half_to_float(a);
+ float fb = half_to_float(b);
+ return fa * fb;
+}
+
+template<>
+inline __device__ float2 mul(uint32_t a, uint32_t b) {
+ float2 fa = half2_to_float2(a);
+ float2 fb = half2_to_float2(b);
+ return mul<float2, float2, float2>(fa, fb);
+}
+
+template<>
+inline __device__ float2 mul(uint16_t a, uint32_t b) {
+ return mul<float2, uint32_t, uint32_t>(h0_h0(a), b);
+}
+
+template<>
+inline __device__ Float4_ mul(uint2 a, uint2 b) {
+ Float4_ fc;
+ fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
+ fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
+ return fc;
+}
+
+template<>
+inline __device__ Float4_ mul(uint16_t a, uint2 b) {
+ uint32_t s = h0_h0(a);
+ Float4_ fc;
+ fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
+ fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
+ return fc;
+}
+
+template<>
+inline __device__ Float8_ mul(uint4 a, uint4 b) {
+ Float8_ fc;
+ fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
+ fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
+ fc.z = mul<float2, uint32_t, uint32_t>(a.z, b.z);
+ fc.w = mul<float2, uint32_t, uint32_t>(a.w, b.w);
+ return fc;
+}
+
+template<>
+inline __device__ Float8_ mul(uint16_t a, uint4 b) {
+ uint32_t s = h0_h0(a);
+ Float8_ fc;
+ fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
+ fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
+ fc.z = mul<float2, uint32_t, uint32_t>(s, b.z);
+ fc.w = mul<float2, uint32_t, uint32_t>(s, b.w);
+ return fc;
+}
+
+// Vector fused multiply-add.
+inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
+ uint32_t d;
+#ifndef USE_ROCM
+ asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
+#else
+ asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c));
+#endif
+ return d;
+}
+
+inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) {
+ return fma(h0_h0(a), b, c);
+}
+
+inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) {
+ uint2 d;
+ d.x = fma(a.x, b.x, c.x);
+ d.y = fma(a.y, b.y, c.y);
+ return d;
+}
+
+inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) {
+ uint32_t s = h0_h0(a);
+ uint2 d;
+ d.x = fma(s, b.x, c.x);
+ d.y = fma(s, b.y, c.y);
+ return d;
+}
+
+inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) {
+ uint4 d;
+ d.x = fma(a.x, b.x, c.x);
+ d.y = fma(a.y, b.y, c.y);
+ d.z = fma(a.z, b.z, c.z);
+ d.w = fma(a.w, b.w, c.w);
+ return d;
+}
+
+inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) {
+ uint32_t s = h0_h0(a);
+ uint4 d;
+ d.x = fma(s, b.x, c.x);
+ d.y = fma(s, b.y, c.y);
+ d.z = fma(s, b.z, c.z);
+ d.w = fma(s, b.w, c.w);
+ return d;
+}
+
+inline __device__ float fma(uint16_t a, uint16_t b, float fc) {
+ float fa = half_to_float(a);
+ float fb = half_to_float(b);
+ return fa * fb + fc;
+}
+
+inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) {
+ float2 fa = half2_to_float2(a);
+ float2 fb = half2_to_float2(b);
+ return fma(fa, fb, fc);
+}
+
+inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) {
+ return fma(h0_h0(a), b, fc);
+}
+
+inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) {
+ Float4_ fd;
+ fd.x = fma(a.x, b.x, fc.x);
+ fd.y = fma(a.y, b.y, fc.y);
+ return fd;
+}
+
+inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) {
+ uint32_t s = h0_h0(a);
+ Float4_ fd;
+ fd.x = fma(s, b.x, fc.x);
+ fd.y = fma(s, b.y, fc.y);
+ return fd;
+}
+
+inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) {
+ Float8_ fd;
+ fd.x = fma(a.x, b.x, fc.x);
+ fd.y = fma(a.y, b.y, fc.y);
+ fd.z = fma(a.z, b.z, fc.z);
+ fd.w = fma(a.w, b.w, fc.w);
+ return fd;
+}
+
+inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) {
+ uint32_t s = h0_h0(a);
+ Float8_ fd;
+ fd.x = fma(s, b.x, fc.x);
+ fd.y = fma(s, b.y, fc.y);
+ fd.z = fma(s, b.z, fc.z);
+ fd.w = fma(s, b.w, fc.w);
+ return fd;
+}
+
+// Vector sum.
+template<>
+inline __device__ float sum(uint16_t v) {
+ return half_to_float(v);
+}
+
+template<>
+inline __device__ float sum(uint32_t v) {
+ float2 tmp = half2_to_float2(v);
+ return tmp.x + tmp.y;
+}
+
+template<>
+inline __device__ float sum(uint2 v) {
+ uint32_t c = add(v.x, v.y);
+ return sum(c);
+}
+
+template<>
+inline __device__ float sum(uint4 v) {
+ uint32_t c = add(v.x, v.y);
+ c = add(c, v.z);
+ c = add(c, v.w);
+ return sum(c);
+}
+
+// From float32 to float16.
+inline __device__ void from_float(uint16_t& dst, float src) {
+ dst = float_to_half(src);
+}
+
+inline __device__ void from_float(uint32_t& dst, float2 src) {
+ dst = float2_to_half2(src);
+}
+
+inline __device__ void from_float(uint2& dst, Float4_ src) {
+ dst.x = float2_to_half2(src.x);
+ dst.y = float2_to_half2(src.y);
+}
+
+inline __device__ void from_float(uint4& dst, Float8_ src) {
+ dst.x = float2_to_half2(src.x);
+ dst.y = float2_to_half2(src.y);
+ dst.z = float2_to_half2(src.z);
+ dst.w = float2_to_half2(src.w);
+}
+
+// From float16 to float32.
+inline __device__ float to_float(uint16_t u) {
+ return half_to_float(u);
+}
+
+inline __device__ float2 to_float(uint32_t u) {
+ return half2_to_float2(u);
+}
+
+inline __device__ Float4_ to_float(uint2 u) {
+ Float4_ tmp;
+ tmp.x = half2_to_float2(u.x);
+ tmp.y = half2_to_float2(u.y);
+ return tmp;
+}
+
+inline __device__ Float8_ to_float(uint4 u) {
+ Float8_ tmp;
+ tmp.x = half2_to_float2(u.x);
+ tmp.y = half2_to_float2(u.y);
+ tmp.z = half2_to_float2(u.z);
+ tmp.w = half2_to_float2(u.w);
+ return tmp;
+}
+
+// Zero-out a variable.
+inline __device__ void zero(uint16_t& dst) {
+ dst = uint16_t(0);
+}
+
+} // namespace vllm
diff --git a/csrc/attention/dtype_float32.cuh b/csrc/attention/dtype_float32.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..b200d2d226eb04792ec3d18a48a5210c40a2d92b
--- /dev/null
+++ b/csrc/attention/dtype_float32.cuh
@@ -0,0 +1,273 @@
+/*
+ * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include "attention_generic.cuh"
+
+#include <stdint.h>
+
+namespace vllm {
+
+// Define custom FP32 vector data types.
+struct Float4_ {
+ float2 x;
+ float2 y;
+};
+
+struct Float8_ {
+ float2 x;
+ float2 y;
+ float2 z;
+ float2 w;
+};
+
+// FP32 vector types for Q, K, V.
+template<>
+struct Vec<float, 1> {
+ using Type = float;
+};
+template<>
+struct Vec<float, 2> {
+ using Type = float2;
+};
+template<>
+struct Vec<float, 4> {
+ using Type = float4;
+};
+
+// FP32 accumulator vector types corresponding to Vec.
+template<>
+struct FloatVec<float> {
+ using Type = float;
+};
+template<>
+struct FloatVec<float2> {
+ using Type = float2;
+};
+template<>
+struct FloatVec<float4> {
+ using Type = float4;
+};
+
+// Vector addition.
+inline __device__ float add(float a, float b) {
+ return a + b;
+}
+
+inline __device__ float2 add(float2 a, float2 b) {
+ float2 c;
+ c.x = add(a.x, b.x);
+ c.y = add(a.y, b.y);
+ return c;
+}
+
+inline __device__ float4 add(float4 a, float4 b) {
+ float4 c;
+ c.x = add(a.x, b.x);
+ c.y = add(a.y, b.y);
+ c.z = add(a.z, b.z);
+ c.w = add(a.w, b.w);
+ return c;
+}
+
+// Vector multiplication.
+template<>
+inline __device__ float mul(float a, float b) {
+ return a * b;
+}
+
+template<>
+inline __device__ float2 mul(float2 a, float2 b) {
+ float2 c;
+ c.x = a.x * b.x;
+ c.y = a.y * b.y;
+ return c;
+}
+
+template<>
+inline __device__ float2 mul(float a, float2 b) {
+ float2 c;
+ c.x = a * b.x;
+ c.y = a * b.y;
+ return c;
+}
+
+template<>
+inline __device__ float4 mul(float4 a, float4 b) {
+ float4 c;
+ c.x = a.x * b.x;
+ c.y = a.y * b.y;
+ c.z = a.z * b.z;
+ c.w = a.w * b.w;
+ return c;
+}
+
+template<>
+inline __device__ float4 mul(float a, float4 b) {
+ float4 c;
+ c.x = a * b.x;
+ c.y = a * b.y;
+ c.z = a * b.z;
+ c.w = a * b.w;
+ return c;
+}
+
+// Vector fused multiply-add.
+inline __device__ float fma(float a, float b, float c) {
+ return a * b + c;
+}
+
+inline __device__ float2 fma(float2 a, float2 b, float2 c) {
+ float2 d;
+ d.x = fma(a.x, b.x, c.x);
+ d.y = fma(a.y, b.y, c.y);
+ return d;
+}
+
+inline __device__ float2 fma(float a, float2 b, float2 c) {
+ float2 d;
+ d.x = fma(a, b.x, c.x);
+ d.y = fma(a, b.y, c.y);
+ return d;
+}
+
+inline __device__ float4 fma(float4 a, float4 b, float4 c) {
+ float4 d;
+ d.x = fma(a.x, b.x, c.x);
+ d.y = fma(a.y, b.y, c.y);
+ d.z = fma(a.z, b.z, c.z);
+ d.w = fma(a.w, b.w, c.w);
+ return d;
+}
+
+inline __device__ float4 fma(float a, float4 b, float4 c) {
+ float4 d;
+ d.x = fma(a, b.x, c.x);
+ d.y = fma(a, b.y, c.y);
+ d.z = fma(a, b.z, c.z);
+ d.w = fma(a, b.w, c.w);
+ return d;
+}
+
+inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) {
+ Float4_ d;
+ d.x = fma(a, b.x, c.x);
+ d.y = fma(a, b.y, c.y);
+ return d;
+}
+
+inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
+ Float8_ d;
+ d.x = fma(a, b.x, c.x);
+ d.y = fma(a, b.y, c.y);
+ d.z = fma(a, b.z, c.z);
+ d.w = fma(a, b.w, c.w);
+ return d;
+}
+
+// Vector sum.
+template<>
+inline __device__ float sum(float v) {
+ return v;
+}
+
+template<>
+inline __device__ float sum(float2 v) {
+ return v.x + v.y;
+}
+
+template<>
+inline __device__ float sum(float4 v) {
+ return v.x + v.y + v.z + v.w;
+}
+
+template<>
+inline __device__ float sum(Float4_ v) {
+ return v.x.x + v.x.y + v.y.x + v.y.y;
+}
+
+template<>
+inline __device__ float sum(Float8_ v) {
+ return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
+}
+
+// Vector dot product.
+inline __device__ float dot(float a, float b) {
+ return a * b;
+}
+
+inline __device__ float dot(float2 a, float2 b) {
+ float2 c = mul<float2, float2, float2>(a, b);
+ return c.x + c.y;
+}
+
+inline __device__ float dot(Float4_ a, Float4_ b) {
+ float2 acc = mul<float2, float2, float2>(a.x, b.x);
+ acc = fma(a.y, b.y, acc);
+ return acc.x + acc.y;
+}
+
+inline __device__ float dot(Float8_ a, Float8_ b) {
+ float2 acc = mul<float2, float2, float2>(a.x, b.x);
+ acc = fma(a.y, b.y, acc);
+ acc = fma(a.z, b.z, acc);
+ acc = fma(a.w, b.w, acc);
+ return acc.x + acc.y;
+}
+
+// From float to float.
+inline __device__ void from_float(float& dst, float src) {
+ dst = src;
+}
+
+inline __device__ void from_float(float2& dst, float2 src) {
+ dst = src;
+}
+
+inline __device__ void from_float(float4& dst, float4 src) {
+ dst = src;
+}
+
+// From float to float.
+inline __device__ float to_float(float u) {
+ return u;
+}
+
+inline __device__ float2 to_float(float2 u) {
+ return u;
+}
+
+inline __device__ float4 to_float(float4 u) {
+ return u;
+}
+
+inline __device__ Float4_ to_float(Float4_ u) {
+ return u;
+}
+
+inline __device__ Float8_ to_float(Float8_ u) {
+ return u;
+}
+
+// Zero-out a variable.
+inline __device__ void zero(float& dst) {
+ dst = 0.f;
+}
+
+} // namespace vllm
diff --git a/csrc/attention/dtype_fp8_e5m2.cuh b/csrc/attention/dtype_fp8_e5m2.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..0580fbb8e863f74f27e6499a86d4fcf92e462514
--- /dev/null
+++ b/csrc/attention/dtype_fp8_e5m2.cuh
@@ -0,0 +1,35 @@
+#pragma once
+
+#include "attention_generic.cuh"
+
+#include <stdint.h>
+#ifdef ENABLE_FP8_E5M2
+#include <cuda_fp8.h>
+#endif
+
+namespace vllm {
+#ifdef ENABLE_FP8_E5M2
+// fp8 vector types for quantization of kv cache
+
+template<>
+struct Vec<uint8_t, 1> {
+ using Type = uint8_t;
+};
+
+template<>
+struct Vec<uint8_t, 2> {
+ using Type = uint16_t;
+};
+
+template<>
+struct Vec<uint8_t, 4> {
+ using Type = uint32_t;
+};
+
+template<>
+struct Vec<uint8_t, 8> {
+ using Type = uint2;
+};
+#endif // ENABLE_FP8_E5M2
+
+} // namespace vllm
diff --git a/csrc/cache.h b/csrc/cache.h
new file mode 100644
index 0000000000000000000000000000000000000000..21c71830f7942cc90cd33fe0dd7c3e0fcc6732b6
--- /dev/null
+++ b/csrc/cache.h
@@ -0,0 +1,36 @@
+#pragma once
+
+#include <torch/extension.h>
+
+#include <map>
+#include <vector>
+
+void swap_blocks(
+ torch::Tensor& src,
+ torch::Tensor& dst,
+ const std::map<int64_t, int64_t>& block_mapping);
+
+void copy_blocks(
+ std::vector<torch::Tensor>& key_caches,
+ std::vector<torch::Tensor>& value_caches,
+ const std::map<int64_t, std::vector<int64_t>>& block_mapping);
+
+void reshape_and_cache(
+ torch::Tensor& key,
+ torch::Tensor& value,
+ torch::Tensor& key_cache,
+ torch::Tensor& value_cache,
+ torch::Tensor& slot_mapping,
+ const std::string& kv_cache_dtype);
+
+void gather_cached_kv(
+ torch::Tensor& key,
+ torch::Tensor& value,
+ torch::Tensor& key_cache,
+ torch::Tensor& value_cache,
+ torch::Tensor& slot_mapping);
+
+// Just for unit tests.
+void convert_fp8_e5m2(
+ torch::Tensor& src_cache,
+ torch::Tensor& dst_cache);
diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
new file mode 100644
index 0000000000000000000000000000000000000000..fe0159e40458502db6d47a95ee598025f111716b
--- /dev/null
+++ b/csrc/cache_kernels.cu
@@ -0,0 +1,474 @@
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+
+#include "cuda_compat.h"
+#include "dispatch_utils.h"
+#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"
+
+#include <algorithm>
+#include <cassert>
+#include <map>
+#include <vector>
+
+void swap_blocks(
+ torch::Tensor& src,
+ torch::Tensor& dst,
+ const std::map<int64_t, int64_t>& block_mapping) {
+ torch::Device src_device = src.device();
+ torch::Device dst_device = dst.device();
+ cudaMemcpyKind memcpy_type;
+ if (src_device.is_cuda() && dst_device.is_cuda()) {
+ TORCH_CHECK(
+ src_device.index() == dst_device.index(),
+ "src and dst must be on the same GPU");
+ memcpy_type = cudaMemcpyDeviceToDevice;
+ } else if (src_device.is_cuda() && dst_device.is_cpu()) {
+ memcpy_type = cudaMemcpyDeviceToHost;
+ } else if (src_device.is_cpu() && dst_device.is_cuda()) {
+ memcpy_type = cudaMemcpyHostToDevice;
+ } else {
+ TORCH_CHECK(false, "Invalid device combination");
+ }
+
+ char *src_ptr = static_cast<char*>(src.data_ptr());
+ char *dst_ptr = static_cast<char*>(dst.data_ptr());
+
+ const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
+ const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device);
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+ // NOTE(woosuk): This can be slow if the number of blocks is large.
+ for (const auto& pair : block_mapping) {
+ int64_t src_block_number = pair.first;
+ int64_t dst_block_number = pair.second;
+ int64_t src_offset = src_block_number * block_size_in_bytes;
+ int64_t dst_offset = dst_block_number * block_size_in_bytes;
+ cudaMemcpyAsync(
+ dst_ptr + dst_offset,
+ src_ptr + src_offset,
+ block_size_in_bytes,
+ memcpy_type,
+ stream);
+ }
+}
+
+namespace vllm {
+
+// Grid: (num_layers, num_pairs)
+template<typename scalar_t>
+__global__ void copy_blocks_kernel(
+ int64_t* key_cache_ptrs,
+ int64_t* value_cache_ptrs,
+ const int64_t* __restrict__ block_mapping,
+ const int numel_per_block) {
+ const int layer_idx = blockIdx.x;
+ const int pair_idx = blockIdx.y;
+
+ scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
+ scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
+ int64_t src_block_number = block_mapping[2 * pair_idx];
+ int64_t dst_block_number = block_mapping[2 * pair_idx + 1];
+
+ const int64_t src_block_offset = src_block_number * numel_per_block;
+ const int64_t dst_block_offset = dst_block_number * numel_per_block;
+ for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
+ int64_t src_offset = src_block_offset + i;
+ int64_t dst_offset = dst_block_offset + i;
+ key_cache[dst_offset] = key_cache[src_offset];
+ }
+ for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
+ int64_t src_offset = src_block_offset + i;
+ int64_t dst_offset = dst_block_offset + i;
+ value_cache[dst_offset] = value_cache[src_offset];
+ }
+}
+
+} // namespace vllm
+
+void copy_blocks(
+ std::vector<torch::Tensor>& key_caches,
+ std::vector<torch::Tensor>& value_caches,
+ const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
+ int num_layers = key_caches.size();
+ TORCH_CHECK(num_layers == value_caches.size());
+ if (num_layers == 0) {
+ return;
+ }
+ torch::Device cache_device = key_caches[0].device();
+ TORCH_CHECK(cache_device.is_cuda());
+
+ // Create data structures for the kernel.
+ // Create an array of pointers to the key and value caches.
+ int64_t key_cache_ptrs[num_layers];
+ int64_t value_cache_ptrs[num_layers];
+ for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
+ key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
+ value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
+ }
+ // Create block mapping array.
+ std::vector<int64_t> block_mapping_vec;
+ for (const auto& pair : block_mapping) {
+ int64_t src_block_number = pair.first;
+ for (int64_t dst_block_number : pair.second) {
+ block_mapping_vec.push_back(src_block_number);
+ block_mapping_vec.push_back(dst_block_number);
+ }
+ }
+ int64_t* block_mapping_array = block_mapping_vec.data();
+ int num_pairs = block_mapping_vec.size() / 2;
+
+ // Move the data structures to the GPU.
+ // NOTE: This synchronizes the CPU and GPU.
+ torch::Tensor key_cache_ptrs_tensor = torch::from_blob(
+ key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
+ torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
+ value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
+ torch::Tensor block_mapping_tensor = torch::from_blob(
+ block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device);
+
+ // Launch the kernel.
+ const int numel_per_block = key_caches[0][0].numel();
+ dim3 grid(num_layers, num_pairs);
+ dim3 block(std::min(1024, numel_per_block));
+ const at::cuda::OptionalCUDAGuard device_guard(cache_device);
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+ VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
+ key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
+ vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
+ key_cache_ptrs_tensor.data_ptr<int64_t>(),
+ value_cache_ptrs_tensor.data_ptr<int64_t>(),
+ block_mapping_tensor.data_ptr<int64_t>(),
+ numel_per_block);
+ }));
+}
+
+namespace vllm {
+
+template<typename scalar_t, typename cache_t, bool is_fp8_e5m2_kv_cache>
+__global__ void reshape_and_cache_kernel(
+ const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
+ const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
+ cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
+ cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
+ const int64_t* __restrict__ slot_mapping, // [num_tokens]
+ const int key_stride,
+ const int value_stride,
+ const int num_heads,
+ const int head_size,
+ const int block_size,
+ const int x) {
+ const int64_t token_idx = blockIdx.x;
+ const int64_t slot_idx = slot_mapping[token_idx];
+ if (slot_idx < 0) {
+ // Padding token that should be ignored.
+ return;
+ }
+
+ const int64_t block_idx = slot_idx / block_size;
+ const int64_t block_offset = slot_idx % block_size;
+
+ const int n = num_heads * head_size;
+ for (int i = threadIdx.x; i < n; i += blockDim.x) {
+ const int64_t src_key_idx = token_idx * key_stride + i;
+ const int64_t src_value_idx = token_idx * value_stride + i;
+
+ const int head_idx = i / head_size;
+ const int head_offset = i % head_size;
+ const int x_idx = head_offset / x;
+ const int x_offset = head_offset % x;
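+ // The key cache splits the head dimension into x-sized chunks
+ // ([num_blocks, num_heads, head_size/x, block_size, x]), while the value
+ // cache keeps the layout [num_blocks, num_heads, head_size, block_size].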
+
+ const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
+ + head_idx * (head_size / x) * block_size * x
+ + x_idx * block_size * x
+ + block_offset * x
+ + x_offset;
+ const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size
+ + head_idx * head_size * block_size
+ + head_offset * block_size
+ + block_offset;
+ scalar_t tgt_key = key[src_key_idx];
+ scalar_t tgt_value = value[src_value_idx];
+ if constexpr (is_fp8_e5m2_kv_cache) {
+#ifdef ENABLE_FP8_E5M2
+ key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
+ value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
+#else
+ assert(false);
+#endif
+ } else {
+ key_cache[tgt_key_idx] = tgt_key;
+ value_cache[tgt_value_idx] = tgt_value;
+ }
+ }
+}
+
+} // namespace vllm
+
+#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
+ vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE><<<grid, block, 0, stream>>>( \
+ reinterpret_cast<KV_T*>(key.data_ptr()), \
+ reinterpret_cast<KV_T*>(value.data_ptr()), \
+ reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
+ reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
+ slot_mapping.data_ptr<int64_t>(), \
+ key_stride, \
+ value_stride, \
+ num_heads, \
+ head_size, \
+ block_size, \
+ x);
+
+void reshape_and_cache(
+ torch::Tensor& key, // [num_tokens, num_heads, head_size]
+ torch::Tensor& value, // [num_tokens, num_heads, head_size]
+ torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
+ torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
+ torch::Tensor& slot_mapping, // [num_tokens]
+ const std::string& kv_cache_dtype)
+{
+ int num_tokens = key.size(0);
+ int num_heads = key.size(1);
+ int head_size = key.size(2);
+ int block_size = key_cache.size(3);
+ int x = key_cache.size(4);
+
+ int key_stride = key.stride(0);
+ int value_stride = value.stride(0);
+
+ dim3 grid(num_tokens);
+ dim3 block(std::min(num_heads * head_size, 512));
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+ if (kv_cache_dtype == "auto") {
+ if (key.dtype() == at::ScalarType::Float) {
+ CALL_RESHAPE_AND_CACHE(float, float, false);
+ } else if (key.dtype() == at::ScalarType::Half) {
+ CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false);
+ } else if (key.dtype() == at::ScalarType::BFloat16) {
+ CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
+ }
+ } else if (kv_cache_dtype == "fp8_e5m2") {
+ if (key.dtype() == at::ScalarType::Float) {
+ CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
+ } else if (key.dtype() == at::ScalarType::Half) {
+ CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true);
+ } else if (key.dtype() == at::ScalarType::BFloat16) {
+ CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true);
+ }
+ } else {
+ TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
+ }
+}
+
+namespace vllm {
+
+// Grid: (num_blocks, block_size).
+template<typename scalar_t>
+__global__ void gather_cached_kv_kernel(
+ scalar_t* __restrict__ key, // [num_tokens, [stride], num_heads, head_size]
+ scalar_t* __restrict__ value, // [num_tokens, [stride], num_heads, head_size]
+ const scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
+ const scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
+ const int* __restrict__ slot_mapping, // [num_tokens]
+ const int key_stride,
+ const int value_stride,
+ const int num_heads,
+ const int head_size,
+ const int block_size,
+ const int x) {
+ const int token_idx = blockIdx.x;
+ const int slot_idx = slot_mapping[token_idx];
+ const int block_idx = slot_idx / block_size;
+ const int block_offset = slot_idx % block_size;
+
+ const int num_tokens = num_heads * head_size;
+ for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) {
+ const int tgt_key_idx = token_idx * key_stride + i;
+ const int tgt_value_idx = token_idx * value_stride + i;
+
+ const int head_idx = i / head_size;
+ const int head_offset = i % head_size;
+ const int x_idx = head_offset / x; // the offset of the [head_size/x] dimension
+ const int x_offset = head_offset % x;
+
+ const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
+ + head_idx * (head_size / x) * block_size * x
+ + x_idx * block_size * x
+ + block_offset * x
+ + x_offset;
+ const int src_value_idx = block_idx * num_heads * head_size * block_size
+ + head_idx * head_size * block_size
+ + head_offset * block_size
+ + block_offset;
+
+ key[tgt_key_idx] = VLLM_LDG(&key_cache[src_key_idx]);
+ value[tgt_value_idx] = VLLM_LDG(&value_cache[src_value_idx]);
+ }
+}
+
+template<typename scalar_t>
+__global__ void gather_cached_kv_kernel_optimized(
+ scalar_t *__restrict__ key, // [num_tokens, [stride], num_heads, head_size]
+ scalar_t *__restrict__ value, // [num_tokens, [stride], num_heads, head_size]
+ const scalar_t *__restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
+ const scalar_t *__restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
+ const int *__restrict__ slot_mapping, // [num_tokens]
+ const int key_stride,
+ const int value_stride,
+ const int num_heads,
+ const int head_size,
+ const int block_size,
+ const int x)
+{
+ const int token_idx = blockIdx.x;
+ const int slot_idx = slot_mapping[token_idx];
+ const int block_idx = slot_idx / block_size;
+ const int block_offset = slot_idx % block_size;
+
+ const int dim = num_heads * head_size;
+ assert(dim % 4 == 0); // this is true for known use cases
+ const int unroll_factor = 4;
+ const int unrolled_dim = dim / unroll_factor;
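+ // Each thread gathers unroll_factor elements (strided by unrolled_dim) before
+ // writing them back, so cache loads and output stores are grouped.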
+
+ for (int i = threadIdx.x; i < unrolled_dim; i += blockDim.x)
+ {
+ int tgt_key_indices[unroll_factor];
+ int tgt_value_indices[unroll_factor];
+ int src_key_indices[unroll_factor];
+ int src_value_indices[unroll_factor];
+ scalar_t keys_to_store[unroll_factor];
+ scalar_t values_to_store[unroll_factor];
+
+ #pragma unroll
+ for (int j = 0; j < unroll_factor; ++j)
+ {
+ int index = i + j * unrolled_dim;
+
+ const int tgt_key_idx = token_idx * key_stride + index;
+ const int tgt_value_idx = token_idx * value_stride + index;
+
+ const int head_idx = index / head_size;
+ const int head_offset = index % head_size;
+ const int x_idx = head_offset / x;
+ const int x_offset = head_offset % x;
+
+ const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
+ + head_idx * (head_size / x) * block_size * x
+ + x_idx * block_size * x
+ + block_offset * x
+ + x_offset;
+ const int src_value_idx = block_idx * num_heads * head_size * block_size
+ + head_idx * head_size * block_size
+ + head_offset * block_size
+ + block_offset;
+
+ tgt_key_indices[j] = tgt_key_idx;
+ tgt_value_indices[j] = tgt_value_idx;
+ src_key_indices[j] = src_key_idx;
+ src_value_indices[j] = src_value_idx;
+
+ keys_to_store[j] = VLLM_LDG(&key_cache[src_key_idx]);
+ values_to_store[j] = VLLM_LDG(&value_cache[src_value_idx]);
+ }
+
+ #pragma unroll
+ for (int j = 0; j < unroll_factor; ++j)
+ {
+ key[tgt_key_indices[j]] = keys_to_store[j];
+ value[tgt_value_indices[j]] = values_to_store[j];
+ }
+ }
+}
+
+} // namespace vllm
+
+void gather_cached_kv(
+ torch::Tensor& key, // [out] [num_tokens, num_heads, head_size]
+ torch::Tensor& value, // [out] [num_tokens, num_heads, head_size]
+ torch::Tensor& key_cache, // [in] [num_blocks, num_heads, head_size/x, block_size, x]
+ torch::Tensor& value_cache, // [in] [num_blocks, num_heads, head_size, block_size]
+ torch::Tensor& slot_mapping) // [in] [num_tokens]
+{
+ int num_tokens = key.size(0);
+ int num_heads = key.size(1);
+ int head_size = key.size(2);
+ int block_size = key_cache.size(3);
+ int x = key_cache.size(4);
+
+ int key_stride = key.stride(0);
+ int value_stride = value.stride(0);
+
+ dim3 grid(num_tokens);
+ dim3 block(std::min(num_heads * head_size, 512));
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+ VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
+ key.scalar_type(),
+ "gather_cached_kv_kernel_optimized",
+ [&] {
+ vllm::gather_cached_kv_kernel_optimized<scalar_t><<<grid, block, 0, stream>>>(
+ key.data_ptr<scalar_t>(),
+ value.data_ptr<scalar_t>(),
+ key_cache.data_ptr<scalar_t>(),
+ value_cache.data_ptr<scalar_t>(),
+ slot_mapping.data_ptr<int>(),
+ key_stride,
+ value_stride,
+ num_heads,
+ head_size,
+ block_size,
+ x);
+ });
+}
+
+namespace vllm {
+
+template<typename Tout, typename Tin>
+__global__ void convert_fp8_e5m2_kernel(
+ const Tin* __restrict__ src_cache,
+ Tout* __restrict__ dst_cache,
+ const int64_t block_stride) {
+ const int64_t block_idx = blockIdx.x;
+ for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
+ int64_t idx = block_idx * block_stride + i;
+#ifdef ENABLE_FP8_E5M2
+ dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]);
+#else
+ assert(false);
+#endif
+ }
+}
+
+} // namespace vllm
+
+#define CALL_CONVERT_FP8_E5M2(Tout, Tin) \
+ vllm::convert_fp8_e5m2_kernel<Tout, Tin><<<grid, block, 0, stream>>>( \
+ reinterpret_cast<Tin*>(src_cache.data_ptr()), \
+ reinterpret_cast<Tout*>(dst_cache.data_ptr()), \
+ block_stride);
+
+void convert_fp8_e5m2(
+ torch::Tensor& src_cache,
+ torch::Tensor& dst_cache)
+{
+ int64_t num_blocks = src_cache.size(0);
+ int64_t block_stride = src_cache.stride(0);
+
+ dim3 grid(num_blocks);
+ dim3 block(std::min(block_stride, int64_t(512)));
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
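+ // The conversion direction follows the dtypes: a floating-point source is quantized
+ // to fp8 (uint8), while a uint8 source is dequantized into the dst tensor's dtype.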
+ if (src_cache.dtype() == at::ScalarType::Float) {
+ CALL_CONVERT_FP8_E5M2(uint8_t, float);
+ } else if (src_cache.dtype() == at::ScalarType::Half) {
+ CALL_CONVERT_FP8_E5M2(uint8_t, uint16_t);
+ } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
+ CALL_CONVERT_FP8_E5M2(uint8_t, __nv_bfloat16);
+ } else if (dst_cache.dtype() == at::ScalarType::Float) {
+ CALL_CONVERT_FP8_E5M2(float, uint8_t);
+ } else if (dst_cache.dtype() == at::ScalarType::Half) {
+ CALL_CONVERT_FP8_E5M2(uint16_t, uint8_t);
+ } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
+ CALL_CONVERT_FP8_E5M2(__nv_bfloat16, uint8_t);
+ }
+}
diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h
new file mode 100644
index 0000000000000000000000000000000000000000..aa58dd73c148a087694209c29b70f6c5e0916678
--- /dev/null
+++ b/csrc/cuda_compat.h
@@ -0,0 +1,28 @@
+#pragma once
+
+#ifndef USE_ROCM
+ #define VLLM_LDG(arg) __ldg(arg)
+#else
+ #define VLLM_LDG(arg) *(arg)
+#endif
+
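+// NOTE: the CUDA *_sync shuffles take an explicit full-warp mask, while ROCm
+// still uses the legacy mask-less intrinsics.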
+#ifndef USE_ROCM
+ #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask)
+#else
+ #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
+#endif
+
+#ifndef USE_ROCM
+ #define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane)
+#else
+ #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
+#endif
+
+#ifndef USE_ROCM
+ #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
+ cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
+#else
+ #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
+ hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
+#endif
+
diff --git a/csrc/cuda_utils.h b/csrc/cuda_utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..1483484faeb4a59f371f367e56f732ef496ac862
--- /dev/null
+++ b/csrc/cuda_utils.h
@@ -0,0 +1,10 @@
+#pragma once
+
+#include <torch/extension.h>
+
+int get_device_attribute(
+ int attribute,
+ int device_id);
+
+int get_max_shared_memory_per_block_device_attribute(
+ int device_id);
diff --git a/csrc/cuda_utils_kernels.cu b/csrc/cuda_utils_kernels.cu
new file mode 100644
index 0000000000000000000000000000000000000000..1a443ef3620ccd38221f9f3a13106e064748b9cc
--- /dev/null
+++ b/csrc/cuda_utils_kernels.cu
@@ -0,0 +1,35 @@
+#ifdef USE_ROCM
+ #include <hip/hip_runtime.h>
+ #include <hip/hip_runtime_api.h>
+#endif
+int get_device_attribute(
+ int attribute,
+ int device_id)
+{
+ int device, value;
+ if (device_id < 0) {
+ cudaGetDevice(&device);
+ }
+ else {
+ device = device_id;
+ }
+ cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device);
+ return value;
+}
+
+
+int get_max_shared_memory_per_block_device_attribute(
+ int device_id)
+{
+int attribute;
+// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
+// cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
+
+#ifdef USE_ROCM
+ attribute = hipDeviceAttributeMaxSharedMemoryPerBlock;
+#else
+ attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin;
+#endif
+
+ return get_device_attribute(attribute, device_id);
+}
diff --git a/csrc/custom_all_reduce.cu b/csrc/custom_all_reduce.cu
new file mode 100644
index 0000000000000000000000000000000000000000..88e4af9d4a99f40de5b09b6f12b165f94d79557c
--- /dev/null
+++ b/csrc/custom_all_reduce.cu
@@ -0,0 +1,148 @@
+#include <ATen/cuda/Exceptions.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <c10/cuda/CUDAStream.h>
+#include <torch/extension.h>
+
+#include "custom_all_reduce.cuh"
+
+// fake pointer type
+using fptr_t = uint64_t;
+static_assert(sizeof(void *) == sizeof(fptr_t));
+
+fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
+ const std::vector<std::string> &handles,
+ const std::vector<int64_t> &offsets, int rank,
+ bool full_nvlink) {
+ int world_size = offsets.size();
+ if (world_size > 8)
+ throw std::invalid_argument("world size > 8 is not supported");
+ if (world_size % 2 != 0)
+ throw std::invalid_argument("Odd num gpus is not supported for now");
+ if (world_size != handles.size())
+ throw std::invalid_argument(
+ "handles length should equal to offsets length");
+ if (rank < 0 || rank >= world_size)
+ throw std::invalid_argument("invalid rank passed in");
+
+ cudaIpcMemHandle_t ipc_handles[8];
+ for (int i = 0; i < world_size; i++) {
+ std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t));
+ }
+ return (fptr_t) new vllm::CustomAllreduce(
+ reinterpret_cast<vllm::Metadata *>(meta.data_ptr()), rank_data.data_ptr(),
+ rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
+}
+
+/**
+ * Make sure tensor t's data lies completely within ((char *)t.data_ptr()) +
+ * t.numel() * t.element_size(). This is slightly weaker than t.is_contiguous()
+ * because it allows transpose of contiguous slice (i.e. slicing the first
+ * dimension). Currently, we require this because stride information is not
+ * passed into the kernels and we treat input tensors as flat.
+ *
+ * Examples
+ * A = torch.zeros(3, 3, 3)
+ * 1. A: OK
+ * 2. A[1:]: OK
+ * 3. A.permute(2, 0, 1): OK
+ * 4. A[1:].permute(2, 0, 1): OK
+ * 5. A[None].expand(2, -1, -1, -1): Not OK
+ * 6. A[:, 1:, 1:]: Not OK
+ */
+bool _is_weak_contiguous(torch::Tensor &t) {
+ return t.is_contiguous() ||
+ (t.storage().nbytes() - t.storage_offset() * t.element_size() ==
+ t.numel() * t.element_size());
+}
+
+bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
+ bool full_nvlink) {
+ auto inp_size = inp.numel() * inp.element_size();
+ // custom allreduce requires input byte size to be multiples of 16
+ if (inp_size % 16 != 0) return false;
+ if (!_is_weak_contiguous(inp)) return false;
+ if (world_size == 2 || full_nvlink) return inp_size <= max_size;
+ // 4 PCIE GPUs use 2 stage allreduce, and is only faster than NCCL when size
+ // <= 512k
+ return world_size <= 4 && inp_size <= 512 * 1024;
+}
+
+void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
+ cudaStream_t stream) {
+ auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
+ TORCH_CHECK(_is_weak_contiguous(out));
+ switch (out.scalar_type()) {
+ case at::ScalarType::Float: {
+ fa->allreduce<float>(stream, reinterpret_cast<float *>(inp.data_ptr()),
+ reinterpret_cast<float *>(out.data_ptr()),
+ out.numel());
+ break;
+ }
+ case at::ScalarType::Half: {
+ fa->allreduce<half>(stream, reinterpret_cast<half *>(inp.data_ptr()),
+ reinterpret_cast<half *>(out.data_ptr()),
+ out.numel());
+ break;
+ }
+#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
+ case at::ScalarType::BFloat16: {
+ fa->allreduce<nv_bfloat16>(
+ stream, reinterpret_cast