from typing import Optional, List, Tuple, TYPE_CHECKING

from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.utils import is_hip, set_cuda_visible_devices, get_ip

logger = init_logger(__name__)
try:
    import ray

    class RayWorkerVllm:
        """Ray wrapper for vllm.worker.Worker, allowing Worker to be
        lazily initialized after Ray sets CUDA_VISIBLE_DEVICES."""

        def __init__(self, init_cached_hf_modules=False) -> None:
            if init_cached_hf_modules:
                from transformers.dynamic_module_utils import init_hf_modules
                init_hf_modules()
            self.worker = None

        def init_worker(self, worker_init_fn):
            self.worker = worker_init_fn()

        def __getattr__(self, name):
            return getattr(self.worker, name)

        def execute_method(self, method, *args, **kwargs):
            executor = getattr(self, method)
            return executor(*args, **kwargs)

        def get_node_ip(self) -> str:
            return get_ip()

        def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
            node_id = ray.get_runtime_context().get_node_id()
            gpu_ids = ray.get_gpu_ids()
            return node_id, gpu_ids

        def set_cuda_visible_devices(self, device_ids) -> None:
            set_cuda_visible_devices(device_ids)
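
    # Usage sketch (assumption, not part of this module): the engine would
    # typically wrap this class as a Ray actor and drive it remotely. The
    # `worker_init_fn` closure below is hypothetical; in practice the engine
    # builds it so that vllm.worker.Worker is constructed only after Ray has
    # set CUDA_VISIBLE_DEVICES for the actor.
    #
    #   worker_cls = ray.remote(num_gpus=1)(RayWorkerVllm)
    #   worker = worker_cls.remote(init_cached_hf_modules=True)
    #   ray.get(worker.init_worker.remote(worker_init_fn))
    #   ray.get(worker.execute_method.remote("get_node_ip"))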
except ImportError as e:
    logger.warning(f"Failed to import Ray with {e!r}. "
                   "For distributed inference, please install Ray with "
                   "`pip install ray`.")
    ray = None
    RayWorkerVllm = None

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

def initialize_cluster(
    parallel_config: ParallelConfig,
    engine_use_ray: bool = False,
    ray_address: Optional[str] = None,
) -> Optional["PlacementGroup"]:
    """Initialize the distributed cluster, using Ray if required.

    Args:
        parallel_config: The configurations for parallel execution.
        engine_use_ray: Whether to use Ray for the async engine.
        ray_address: The address of the Ray cluster. If None, uses
            the default Ray cluster address.

    Returns:
        An optional `PlacementGroup`. It includes the specification
        of the resources for each distributed worker. None if Ray is
        not used.
    """
    if parallel_config.worker_use_ray or engine_use_ray:
        if ray is None:
            raise ImportError(
                "Ray is not installed. Please install Ray to use distributed "
                "serving.")
        # Connect to a Ray cluster. On ROCm, Ray may not auto-detect AMD
        # GPUs, so the GPU count is passed explicitly.
        if is_hip():
            ray.init(address=ray_address,
                     ignore_reinit_error=True,
                     num_gpus=parallel_config.world_size)
        else:
            ray.init(address=ray_address, ignore_reinit_error=True)

    if not parallel_config.worker_use_ray:
        assert parallel_config.world_size == 1, (
            "Ray is required if parallel_config.world_size > 1.")
        return None
    # Create placement group for worker processes.
    current_placement_group = ray.util.get_current_placement_group()
    if current_placement_group:
        # We are in a placement group.
        bundles = current_placement_group.bundle_specs
        # Verify that we can use the placement group.
        gpu_bundles = 0
        for bundle in bundles:
            bundle_gpus = bundle.get("GPU", 0)
            if bundle_gpus > 1:
                raise ValueError(
                    "Placement group bundle cannot have more than 1 GPU.")
            if bundle_gpus:
                gpu_bundles += 1
        if parallel_config.world_size > gpu_bundles:
            raise ValueError(
                "The number of required GPUs exceeds the total number of "
                "available GPUs in the placement group.")
    else:
        num_gpus_in_cluster = ray.cluster_resources().get("GPU", 0)
        if parallel_config.world_size > num_gpus_in_cluster:
            raise ValueError(
                "The number of required GPUs exceeds the total number of "
                "available GPUs in the cluster.")
        # Create a new placement group.
        placement_group_specs = ([{"GPU": 1}] * parallel_config.world_size)
        current_placement_group = ray.util.placement_group(
            placement_group_specs)
        # Wait until the placement group is ready. This blocks until all
        # requested resources are available, and times out if they cannot
        # be provisioned.
        ray.get(current_placement_group.ready(), timeout=1800)

    return current_placement_group
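
# Usage sketch (assumption, for illustration only): how an engine might call
# initialize_cluster. The ParallelConfig arguments shown here are illustrative;
# the real constructor takes additional parameters.
#
#   parallel_config = ParallelConfig(pipeline_parallel_size=1,
#                                    tensor_parallel_size=2,
#                                    worker_use_ray=True)
#   placement_group = initialize_cluster(parallel_config)
#   # Returns None when Ray is not used (world_size == 1); otherwise it
#   # reserves one 1-GPU bundle per distributed worker.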