import enum
import time
from typing import Dict, Iterable, List, Optional, Tuple, Union

from vllm.config import CacheConfig, SchedulerConfig
from vllm.core.block_manager import AllocStatus, BlockSpaceManager
from vllm.core.policy import PolicyFactory
from vllm.logger import init_logger
from vllm.sequence import (
    Sequence,
    SequenceData,
    SequenceGroup,
    SequenceGroupMetadata,
    SequenceStatus,
)

logger = init_logger(__name__)


class PreemptionMode(enum.Enum):
    """Preemption modes.

    1. Swapping: Swap out the blocks of the preempted sequences to CPU memory
    and swap them back in when the sequences are resumed.
    2. Recomputation: Discard the blocks of the preempted sequences and
    recompute them when the sequences are resumed, treating the sequences as
    new prompts.
    """

    SWAP = enum.auto()
    RECOMPUTE = enum.auto()


class SchedulerOutputs:
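    """The outputs of one scheduling iteration.

    Bundles the sequence groups scheduled to run next with the block swap-in,
    swap-out, and copy operations that must be applied to the KV cache before
    model execution, plus any sequence groups that were ignored (e.g.,
    over-long prompts).
    """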

    def __init__(
        self,
        scheduled_seq_groups: List[SequenceGroup],
        prompt_run: bool,
        num_batched_tokens: int,
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        ignored_seq_groups: List[SequenceGroup],
    ) -> None:
        self.scheduled_seq_groups = scheduled_seq_groups
        self.prompt_run = prompt_run
        self.num_batched_tokens = num_batched_tokens
        self.blocks_to_swap_in = blocks_to_swap_in
        self.blocks_to_swap_out = blocks_to_swap_out
        self.blocks_to_copy = blocks_to_copy
        # Swap in and swap out should never happen at the same time.
        assert not (blocks_to_swap_in and blocks_to_swap_out)
        self.ignored_seq_groups = ignored_seq_groups

    def is_empty(self) -> bool:
        # NOTE: We do not consider the ignored sequence groups.
        return (
            not self.scheduled_seq_groups
            and not self.blocks_to_swap_in
            and not self.blocks_to_swap_out
            and not self.blocks_to_copy
        )


class Scheduler:
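    """Schedules sequence groups onto the GPU under KV cache constraints.

    Sequence groups move between three queues: ``waiting`` (prompts not yet
    allocated), ``running`` (allocated on the GPU), and ``swapped``
    (preempted to CPU memory). Each call to ``schedule()`` picks the groups
    to run in the next iteration and emits the required block swap and copy
    operations.
    """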

    def __init__(
        self,
        scheduler_config: SchedulerConfig,
        cache_config: CacheConfig,
    ) -> None:
        self.scheduler_config = scheduler_config
        self.cache_config = cache_config

        self.prompt_limit = min(
            self.scheduler_config.max_model_len,
            self.scheduler_config.max_num_batched_tokens,
        )

        # Instantiate the scheduling policy.
        self.policy = PolicyFactory.get_policy(policy_name="fcfs")
        # Create the block space manager.
        self.block_manager = BlockSpaceManager(
            block_size=self.cache_config.block_size,
            num_gpu_blocks=self.cache_config.num_gpu_blocks,
            num_cpu_blocks=self.cache_config.num_cpu_blocks,
            sliding_window=self.cache_config.sliding_window,
        )

        # Sequence groups in the WAITING state.
        self.waiting: List[SequenceGroup] = []
        # Sequence groups in the RUNNING state.
        self.running: List[SequenceGroup] = []
        # Sequence groups in the SWAPPED state.
        self.swapped: List[SequenceGroup] = []

    def add_seq_group(self, seq_group: SequenceGroup) -> None:
        # Add the sequence group to the waiting queue.
        self.waiting.append(seq_group)

    def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
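        """Aborts the sequence group(s) with the given ID(s).

        Checks each state queue for the given request ID(s); any match is
        removed from its queue, and each of its unfinished sequences is freed
        with status ``FINISHED_ABORTED``. IDs not found are silently ignored.

        Args:
            request_id: The ID(s) of the sequence group(s) to abort.
        """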
        if isinstance(request_id, str):
            request_id = (request_id,)
        request_ids = set(request_id)
        for state_queue in [self.waiting, self.running, self.swapped]:
            # We iterate in reverse because we remove elements from the list
            # as we go; iterating forward would shift the remaining indices
            # and skip over elements.
            for seq_group in reversed(state_queue):
                if seq_group.request_id in request_ids:
                    # Remove the sequence group from the state queue.
                    state_queue.remove(seq_group)
                    for seq in seq_group.get_seqs():
                        if seq.is_finished():
                            continue
                        seq.status = SequenceStatus.FINISHED_ABORTED
                        self.free_seq(seq)
                    request_ids.remove(seq_group.request_id)
                    if not request_ids:
                        return

    def has_unfinished_seqs(self) -> bool:
        return bool(self.waiting or self.running or self.swapped)

    def get_num_unfinished_seq_groups(self) -> int:
        return len(self.waiting) + len(self.running) + len(self.swapped)

    def _schedule(self) -> SchedulerOutputs:
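        """Performs one scheduling iteration.

        Prefill runs take priority: when no groups are swapped out, waiting
        groups are admitted until a token, sequence, or padding limit is hit.
        Otherwise, running groups are scheduled for one decoding step,
        preempting the lowest-priority groups when the block manager cannot
        reserve new slots; swapped-out groups are swapped back in only when
        no group was preempted in this step.
        """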
        # Blocks that need to be swapped or copied before model execution.
        blocks_to_swap_in: Dict[int, int] = {}
        blocks_to_swap_out: Dict[int, int] = {}
        blocks_to_copy: Dict[int, List[int]] = {}

        # Fix the current time.
        now = time.monotonic()

        # Join waiting sequences if possible.
        if not self.swapped:
            ignored_seq_groups: List[SequenceGroup] = []
            scheduled: List[SequenceGroup] = []
            # The total number of sequences on the fly, including the
            # requests in the generation phase.
            num_curr_seqs = sum(
                seq_group.get_max_num_running_seqs() for seq_group in self.running
            )
            seq_lens: List[int] = []

            # Optimization: We do not sort the waiting queue since preempted
            # sequence groups are added to the front and new sequence groups
            # are added to the back.
            while self.waiting:
                seq_group = self.waiting[0]

                waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
                assert len(waiting_seqs) == 1, (
                    "Waiting sequence group should have only one prompt sequence."
                )
                num_prompt_tokens = waiting_seqs[0].get_len()
                if num_prompt_tokens > self.prompt_limit:
                    logger.warning(
                        f"Input prompt ({num_prompt_tokens} tokens) is too long"
                        f" and exceeds limit of {self.prompt_limit}"
                    )
                    for seq in waiting_seqs:
                        seq.status = SequenceStatus.FINISHED_IGNORED
                    ignored_seq_groups.append(seq_group)
                    self.waiting.pop(0)
                    continue

                # If the sequence group cannot be allocated, stop.
                can_allocate = self.block_manager.can_allocate(seq_group)
                if can_allocate == AllocStatus.LATER:
                    break
                elif can_allocate == AllocStatus.NEVER:
                    logger.warning(
                        f"Input prompt ({num_prompt_tokens} tokens) is too long"
                        f" and exceeds the capacity of block_manager"
                    )
                    for seq in waiting_seqs:
                        seq.status = SequenceStatus.FINISHED_IGNORED
                    ignored_seq_groups.append(seq_group)
                    self.waiting.pop(0)
                    continue

                # If the number of batched tokens exceeds the limit, stop.
                new_seq_lens = seq_lens + [num_prompt_tokens]
                num_batched_tokens = len(new_seq_lens) * max(new_seq_lens)
                if num_batched_tokens > self.scheduler_config.max_num_batched_tokens:
                    break

                # The total number of sequences in the RUNNING state should
                # not exceed the maximum number of sequences.
                num_new_seqs = seq_group.get_max_num_running_seqs()
                if num_curr_seqs + num_new_seqs > self.scheduler_config.max_num_seqs:
                    break

                # If the number of paddings exceeds the limit, stop.
                num_paddings = num_batched_tokens - sum(new_seq_lens)
                if num_paddings > self.scheduler_config.max_paddings:
                    break
                seq_lens = new_seq_lens

                seq_group = self.waiting.pop(0)
                self._allocate(seq_group)
                self.running.append(seq_group)
                num_curr_seqs += num_new_seqs
                scheduled.append(seq_group)

            if scheduled or ignored_seq_groups:
                scheduler_outputs = SchedulerOutputs(
                    scheduled_seq_groups=scheduled,
                    prompt_run=True,
                    num_batched_tokens=(
                        len(seq_lens) * max(seq_lens) if seq_lens else 0
                    ),
                    blocks_to_swap_in=blocks_to_swap_in,
                    blocks_to_swap_out=blocks_to_swap_out,
                    blocks_to_copy=blocks_to_copy,
                    ignored_seq_groups=ignored_seq_groups,
                )
                return scheduler_outputs

        # NOTE: Preemption happens only when there is no available slot to
        # keep all sequence groups in the RUNNING state. In that case, the
        # policy is responsible for deciding which sequence groups to preempt.
        self.running = self.policy.sort_by_priority(now, self.running)

        # Reserve new token slots for the running sequence groups.
        running: List[SequenceGroup] = []
        preempted: List[SequenceGroup] = []
        while self.running:
            seq_group = self.running.pop(0)
            while not self.block_manager.can_append_slot(seq_group):
                if self.running:
                    # Preempt the lowest-priority sequence group.
                    victim_seq_group = self.running.pop(-1)
                    self._preempt(victim_seq_group, blocks_to_swap_out)
                    preempted.append(victim_seq_group)
                else:
                    # No other sequence group can be preempted.
                    # Preempt the current sequence group.
                    self._preempt(seq_group, blocks_to_swap_out)
                    preempted.append(seq_group)
                    break
            else:
                # Append new slots to the sequence group.
                self._append_slot(seq_group, blocks_to_copy)
                running.append(seq_group)
        self.running = running

        # Swap in the sequence groups in the SWAPPED state if possible.
        self.swapped = self.policy.sort_by_priority(now, self.swapped)
        if not preempted:
            num_curr_seqs = sum(
                seq_group.get_max_num_running_seqs() for seq_group in self.running
            )

            while self.swapped:
                seq_group = self.swapped[0]
                # If the sequence group cannot be swapped in, stop.
                if not self.block_manager.can_swap_in(seq_group):
                    break

                # The total number of sequences in the RUNNING state should
                # not exceed the maximum number of sequences.
                num_new_seqs = seq_group.get_max_num_running_seqs()
                if num_curr_seqs + num_new_seqs > self.scheduler_config.max_num_seqs:
                    break

                seq_group = self.swapped.pop(0)
                self._swap_in(seq_group, blocks_to_swap_in)
                self._append_slot(seq_group, blocks_to_copy)
                num_curr_seqs += num_new_seqs
                self.running.append(seq_group)

        # Each sequence in the generation phase only takes one token slot.
        # Therefore, the number of batched tokens equals the number of
        # sequences in the RUNNING state.
        num_batched_tokens = sum(
            seq_group.num_seqs(status=SequenceStatus.RUNNING)
            for seq_group in self.running
        )

        scheduler_outputs = SchedulerOutputs(
            scheduled_seq_groups=self.running,
            prompt_run=False,
            num_batched_tokens=num_batched_tokens,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
            ignored_seq_groups=[],
        )
        return scheduler_outputs

    def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
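        """Schedules the next batch and builds its per-group metadata.

        Note that ``_schedule()`` mutates the scheduler's internal state
        (``self.waiting``, ``self.running``, and ``self.swapped``), so the
        metadata is constructed only after it returns.
        """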
        scheduler_outputs = self._schedule()

        # Create input data structures.
        seq_group_metadata_list: List[SequenceGroupMetadata] = []
        for seq_group in scheduler_outputs.scheduled_seq_groups:
            seq_data: Dict[int, SequenceData] = {}
            block_tables: Dict[int, List[int]] = {}
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                seq_id = seq.seq_id
                seq_data[seq_id] = seq.data
                block_tables[seq_id] = self.block_manager.get_block_table(seq)

            seq_group_metadata = SequenceGroupMetadata(
                request_id=seq_group.request_id,
                is_prompt=scheduler_outputs.prompt_run,
                seq_data=seq_data,
                sampling_params=seq_group.sampling_params,
                block_tables=block_tables,
            )
            seq_group_metadata_list.append(seq_group_metadata)
        return seq_group_metadata_list, scheduler_outputs

    def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
        self.block_manager.fork(parent_seq, child_seq)

    def free_seq(self, seq: Sequence) -> None:
        self.block_manager.free(seq)

    def free_finished_seq_groups(self) -> None:
        self.running = [
            seq_group for seq_group in self.running if not seq_group.is_finished()
        ]

    def _allocate(self, seq_group: SequenceGroup) -> None:
        self.block_manager.allocate(seq_group)
        for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
            seq.status = SequenceStatus.RUNNING

    def _append_slot(
        self,
        seq_group: SequenceGroup,
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            ret = self.block_manager.append_slot(seq)
            if ret is not None:
                # append_slot returned a (src, dst) pair: the src block must
                # be copied to the dst block before the next model step.
                src_block, dst_block = ret
                if src_block in blocks_to_copy:
                    blocks_to_copy[src_block].append(dst_block)
                else:
                    blocks_to_copy[src_block] = [dst_block]

    def _preempt(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
        preemption_mode: Optional[PreemptionMode] = None,
    ) -> None:
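        # If the preemption mode is not specified, choose one automatically:
        # recomputation is preferred since it is cheaper than swapping, but it
        # only supports a single running sequence (the group is re-queued as a
        # new prompt), so groups that may run multiple sequences (e.g., beam
        # search) are swapped out instead.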
        if preemption_mode is None:
            if seq_group.get_max_num_running_seqs() == 1:
                preemption_mode = PreemptionMode.RECOMPUTE
            else:
                preemption_mode = PreemptionMode.SWAP
        if preemption_mode == PreemptionMode.RECOMPUTE:
            self._preempt_by_recompute(seq_group)
        elif preemption_mode == PreemptionMode.SWAP:
            self._preempt_by_swap(seq_group, blocks_to_swap_out)
        else:
            raise AssertionError("Invalid preemption mode.")

    def _preempt_by_recompute(
        self,
        seq_group: SequenceGroup,
    ) -> None:
        seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        assert len(seqs) == 1
        for seq in seqs:
            seq.status = SequenceStatus.WAITING
            self.block_manager.free(seq)

        # NOTE: For FCFS, we insert the preempted sequence group at the front
        # of the waiting queue.
        self.waiting.insert(0, seq_group)

    def _preempt_by_swap(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
    ) -> None:
        self._swap_out(seq_group, blocks_to_swap_out)
        self.swapped.append(seq_group)

    def _swap_in(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_in: Dict[int, int],
    ) -> None:
        mapping = self.block_manager.swap_in(seq_group)
        blocks_to_swap_in.update(mapping)
        for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
            seq.status = SequenceStatus.RUNNING

    def _swap_out(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
    ) -> None:
        if not self.block_manager.can_swap_out(seq_group):
            # FIXME: Abort the sequence group instead of aborting the
            # entire engine.
            raise RuntimeError(
                "Aborted due to the lack of CPU swap space. Please increase "
                "the swap space to avoid this error."
            )
        mapping = self.block_manager.swap_out(seq_group)
        blocks_to_swap_out.update(mapping)
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            seq.status = SequenceStatus.SWAPPED
|