# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from typing import Sequence

import torch


def ensure_divisibility(numerator, denominator):
    """Ensure that numerator is divisible by the denominator."""
    assert numerator % denominator == 0, "{} is not divisible by {}".format(
        numerator, denominator)


def divide(numerator, denominator):
    """Ensure that numerator is divisible by the denominator and return
    the division value."""
    ensure_divisibility(numerator, denominator)
    return numerator // denominator
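

# Example (an illustrative sketch, not part of the original module): divide()
# computes per-partition sizes for tensor parallelism, e.g. splitting a hidden
# dimension evenly across ranks, and fails loudly when the split is uneven.
#
#     >>> divide(4096, 8)
#     512
#     >>> divide(10, 3)  # raises AssertionError: 10 is not divisible by 3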


def split_tensor_along_last_dim(
    tensor: torch.Tensor,
    num_partitions: int,
    contiguous_split_chunks: bool = False,
) -> Sequence[torch.Tensor]:
    """Split a tensor along its last dimension.

    Arguments:
        tensor: input tensor.
        num_partitions: number of partitions to split the tensor into.
        contiguous_split_chunks: If True, make each chunk contiguous
                                 in memory.

    Returns:
        A list of Tensors.
    """
    # Get the size and dimension.
    last_dim = tensor.dim() - 1
    last_dim_size = divide(tensor.size()[last_dim], num_partitions)
    # Split.
    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
    # NOTE: torch.split does not create contiguous tensors by default.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in tensor_list)

    return tensor_list
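

# Example usage (an illustrative sketch, not part of the original module):
# splitting a [batch, seq_len, hidden] activation across 4 tensor-parallel
# partitions; each chunk keeps every dimension except the last, which shrinks
# from 8 to 2.
#
#     >>> x = torch.randn(2, 4, 8)
#     >>> chunks = split_tensor_along_last_dim(x, 4, contiguous_split_chunks=True)
#     >>> [tuple(c.shape) for c in chunks]
#     [(2, 4, 2), (2, 4, 2), (2, 4, 2), (2, 4, 2)]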