"""Sequence and its related classes.""" |
import copy |
import enum |
from typing import Dict, List, Optional, Union |
import torch |
from vllm.block import LogicalTokenBlock |
from .sampling_params import SamplingParams |
# Logprobs for the prompt: one entry per position, each either a mapping
# (presumably token id -> logprob, as for sampled tokens) or None.
PromptLogprobs = List[Union[Dict[int, float], None]]
# Logprobs for sampled output: one mapping (token id -> logprob) per entry.
SampleLogprobs = List[Dict[int, float]]
class SequenceStatus(enum.Enum):
    """Status of a sequence."""

    WAITING = enum.auto()
    RUNNING = enum.auto()
    SWAPPED = enum.auto()
    FINISHED_STOPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()
    FINISHED_IGNORED = enum.auto()
    # Referenced by get_finished_reason(). Declared last so the auto() values
    # of the members above are not shifted.
    FINISHED_LENGTH_CAPPED = enum.auto()

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        """Return True if `status` is one of the FINISHED_* states."""
        return status in [
            SequenceStatus.FINISHED_STOPPED,
            SequenceStatus.FINISHED_LENGTH_CAPPED,
            SequenceStatus.FINISHED_ABORTED,
            SequenceStatus.FINISHED_IGNORED,
        ]

    @staticmethod
    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
        """Map a status to its finish-reason string.

        Returns:
            "stop", "length" or "abort" for the FINISHED_* states, and None
            for every other status.
        """
        if status == SequenceStatus.FINISHED_STOPPED:
            finish_reason = "stop"
        elif status == SequenceStatus.FINISHED_LENGTH_CAPPED:
            finish_reason = "length"
        elif status == SequenceStatus.FINISHED_ABORTED:
            finish_reason = "abort"
        elif status == SequenceStatus.FINISHED_IGNORED:
            # NOTE(review): ignored sequences report "length" as well,
            # presumably because they were rejected for exceeding a length
            # limit -- confirm against the scheduler.
            finish_reason = "length"
        else:
            finish_reason = None
        return finish_reason
class SequenceData:
    """Data associated with a sequence.

    Args:
        prompt_token_ids: The token IDs of the prompt.

    Attributes:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output.
        cumulative_logprob: The cumulative log probability of the output.
            Stays a float while scalar logprobs are appended; becomes a list
            of per-component running totals once a sequence of logprobs is
            appended.
        hidden_states: Hidden states appended so far, concatenated along
            dim 0, or None if nothing has been appended.
        finished: Flag stored with the data; initialised to False.
    """

    def __init__(
        self,
        prompt_token_ids: List[int],
    ) -> None:
        self.prompt_token_ids = prompt_token_ids
        self.output_token_ids: List[int] = []
        self.cumulative_logprob: Union[float, List[float]] = 0.0
        self.hidden_states: Optional[torch.Tensor] = None
        self.finished = False

    def append_token_id(
        self,
        token_id: int,
        logprob: Union[float, List[float]],
    ) -> None:
        """Append an output token and fold its logprob into the running total.

        Args:
            token_id: The token ID to append to the output.
            logprob: Either a single number, or an indexable sequence of
                numbers that is accumulated component-wise.

        Raises:
            TypeError: If a scalar logprob is appended after the running
                total has already become a list.
        """
        if isinstance(logprob, (int, float)):
            # Scalar path: the declared `float` contract, which previously
            # failed on len(logprob).
            if isinstance(self.cumulative_logprob, list):
                raise TypeError(
                    "Cannot add a scalar logprob to a per-component "
                    "cumulative logprob."
                )
            self.output_token_ids.append(token_id)
            self.cumulative_logprob += logprob
            return

        if isinstance(self.cumulative_logprob, float):
            # First vector logprob: switch to one running total per
            # component, seeded with the scalar total so far (0.0 unless
            # scalar logprobs were appended earlier).
            self.cumulative_logprob = [self.cumulative_logprob] * len(logprob)
        self.output_token_ids.append(token_id)
        for i in range(len(self.cumulative_logprob)):
            self.cumulative_logprob[i] += logprob[i]

    def append_hidden_states(
        self,
        hidden_states: Optional[torch.Tensor],
    ) -> None:
        """Concatenate `hidden_states` onto the stored tensor along dim 0.

        Passing None leaves the stored hidden states unchanged; handing None
        to torch.cat would raise.
        """
        if hidden_states is None:
            return
        if self.hidden_states is None:
            self.hidden_states = hidden_states
        else:
            self.hidden_states = torch.cat(
                [self.hidden_states, hidden_states], dim=0
            )

    def get_len(self) -> int:
        """Return the total number of tokens (prompt + output)."""
        return len(self.output_token_ids) + len(self.prompt_token_ids)

    def get_prompt_len(self) -> int:
        """Return the number of prompt tokens."""
        return len(self.prompt_token_ids)

    def get_output_len(self) -> int:
        """Return the number of output tokens."""
        return len(self.output_token_ids)

    def get_token_ids(self) -> List[int]:
        """Return a new list holding the prompt tokens then the output tokens."""
        return self.prompt_token_ids + self.output_token_ids

    def get_last_token_id(self) -> int:
        """Return the last output token, or the last prompt token if there
        is no output yet."""
        if not self.output_token_ids:
            return self.prompt_token_ids[-1]
        return self.output_token_ids[-1]

    def __repr__(self) -> str:
        # Only the shape of the hidden states is shown, not the tensor itself.
        hidden_shape = (
            self.hidden_states.shape if self.hidden_states is not None else None
        )
        return (
            f"SequenceData("
            f"prompt_token_ids={self.prompt_token_ids}, "
            f"output_token_ids={self.output_token_ids}, "
            f"cumulative_logprob={self.cumulative_logprob}, "
            f"hidden_states={hidden_shape}, "
            f"finished={self.finished})"
        )
class Sequence:
    """Stores the data, status, and block information of a sequence.

    Args:
        seq_id: The ID of the sequence.
        prompt: The prompt of the sequence.
        prompt_token_ids: The token IDs of the prompt.
        block_size: The block size of the sequence. Should be the same as the
            block size used by the block manager and cache engine.
    """

    def __init__(
        self,
        seq_id: int,
        prompt: str,
        prompt_token_ids: List[int],
        block_size: int,
    ) -> None:
        self.seq_id = seq_id
        self.prompt = prompt
        self.block_size = block_size

        self.data = SequenceData(prompt_token_ids)
        self.output_logprobs: SampleLogprobs = []
        self.output_text = ""

        self.logical_token_blocks: List[LogicalTokenBlock] = []
        # Lay the prompt tokens out into logical blocks.
        self._append_tokens_to_blocks(prompt_token_ids)
        self.status = SequenceStatus.WAITING

        # NOTE(review): these three are only initialised here and never read
        # in this class; presumably detokenization bookkeeping -- confirm
        # against the caller.
        self.prefix_offset = 0
        self.read_offset = 0
        self.tokens: Optional[List[str]] = None

    def _append_logical_block(self) -> None:
        """Append a new logical block numbered by its position in the list."""
        block = LogicalTokenBlock(
            block_number=len(self.logical_token_blocks),
            block_size=self.block_size,
        )
        self.logical_token_blocks.append(block)

    def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
        """Write `token_ids` into the logical blocks, adding blocks as needed."""
        cursor = 0
        while cursor < len(token_ids):
            if not self.logical_token_blocks:
                self._append_logical_block()

            last_block = self.logical_token_blocks[-1]
            if last_block.is_full():
                self._append_logical_block()
                last_block = self.logical_token_blocks[-1]

            # Fill whatever room the last block has, then move on.
            num_empty_slots = last_block.get_num_empty_slots()
            last_block.append_tokens(token_ids[cursor : cursor + num_empty_slots])
            cursor += num_empty_slots

    def append_token_id(
        self,
        token_id: int,
        logprobs: Dict[int, float],
        hidden_states: Optional[torch.Tensor] = None,
        finished: bool = False,
    ) -> None:
        """Append one generated token and its associated outputs.

        Args:
            token_id: The generated token ID; must be a key of `logprobs`.
            logprobs: The logprobs recorded for this step, keyed by token ID.
            hidden_states: Optional hidden states to append to the sequence
                data. Skipped when None.
            finished: Value stored on `self.data.finished`.
        """
        assert token_id in logprobs
        self._append_tokens_to_blocks([token_id])
        self.output_logprobs.append(logprobs)
        self.data.append_token_id(token_id, logprobs[token_id])
        # The parameter defaults to None, and forwarding None would reach
        # torch.cat once a tensor has been stored, which raises.
        if hidden_states is not None:
            self.data.append_hidden_states(hidden_states)
        self.data.finished = finished

    def get_len(self) -> int:
        """Return the total number of tokens (prompt + output)."""
        return self.data.get_len()

    def get_prompt_len(self) -> int:
        """Return the number of prompt tokens."""
        return self.data.get_prompt_len()

    def get_output_len(self) -> int:
        """Return the number of output tokens."""
        return self.data.get_output_len()

    def get_token_ids(self) -> List[int]:
        """Return the prompt token IDs followed by the output token IDs."""
        return self.data.get_token_ids()

    def get_last_token_id(self) -> int:
        """Return the most recent token ID of the sequence."""
        return self.data.get_last_token_id()

    def get_output_token_ids(self) -> List[int]:
        """Return the output token ID list (the stored list, not a copy)."""
        return self.data.output_token_ids

    def get_cumulative_logprob(self) -> float:
        """Return the cumulative logprob held by the sequence data."""
        return self.data.cumulative_logprob

    def get_beam_search_score(
        self,
        length_penalty: float = 0.0,
        seq_len: Optional[int] = None,
        eos_token_id: Optional[int] = None,
    ) -> float:
        """Calculate the beam search score with length penalty.

        Adapted from
        https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938

        Args:
            length_penalty: Exponent applied to the sequence length.
            seq_len: Length to normalise by. Defaults to the current length,
                minus one if the last token equals `eos_token_id`.
            eos_token_id: Optional end-of-sequence token ID.

        Returns:
            The cumulative logprob divided by `seq_len ** length_penalty`.
        """
        if seq_len is None:
            seq_len = self.get_len()
            # A trailing EOS token is not counted towards the length.
            if eos_token_id is not None and self.get_last_token_id() == eos_token_id:
                seq_len -= 1
        # NOTE(review): this division needs a scalar cumulative logprob;
        # SequenceData.append_token_id can turn it into a list -- confirm
        # which form beam search is expected to see.
        return self.get_cumulative_logprob() / (seq_len**length_penalty)

    def is_finished(self) -> bool:
        """Return True if the sequence is in a finished status."""
        return SequenceStatus.is_finished(self.status)

    def fork(self, new_seq_id: int) -> "Sequence":
        """Return a deep copy of this sequence carrying `new_seq_id`."""
        new_seq = copy.deepcopy(self)
        new_seq.seq_id = new_seq_id
        return new_seq

    def __repr__(self) -> str:
        return (
            f"Sequence(seq_id={self.seq_id}, "
            f"status={self.status.name}, "
            f"num_blocks={len(self.logical_token_blocks)})"
        )
class SequenceGroup:
    """A set of sequences that all originate from one prompt.

    Args:
        request_id: The ID of the request.
        seqs: The list of sequences.
        sampling_params: The sampling parameters used to generate the outputs.
        arrival_time: The arrival time of the request.
    """

    def __init__(
        self,
        request_id: str,
        seqs: List[Sequence],
        sampling_params: SamplingParams,
        arrival_time: float,
    ) -> None:
        self.request_id = request_id
        # Index the sequences by ID; a later duplicate ID replaces an earlier one.
        by_id = {}
        for member in seqs:
            by_id[member.seq_id] = member
        self.seqs_dict = by_id
        self.sampling_params = sampling_params
        self.arrival_time = arrival_time
        self.prompt_logprobs: Optional[PromptLogprobs] = None

    @property
    def prompt(self) -> str:
        """Prompt text, read from the first sequence in the group."""
        first_seq = next(iter(self.seqs_dict.values()))
        return first_seq.prompt

    @property
    def prompt_token_ids(self) -> List[int]:
        """Prompt token IDs, read from the first sequence in the group."""
        first_seq = next(iter(self.seqs_dict.values()))
        return first_seq.data.prompt_token_ids

    def get_max_num_running_seqs(self) -> int:
        """The maximum number of sequences running in parallel in the remaining
        lifetime of the request."""
        params = self.sampling_params
        # Beam search always reports `best_of`; so does a group that holds
        # fewer than `best_of` sequences. Otherwise, count what is unfinished.
        if params.use_beam_search or params.best_of > self.num_seqs():
            return params.best_of
        return self.num_unfinished_seqs()

    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> List[Sequence]:
        """Return all sequences, or only those whose status equals `status`."""
        members = list(self.seqs_dict.values())
        if status is None:
            return members
        return [member for member in members if member.status == status]

    def get_unfinished_seqs(self) -> List[Sequence]:
        """Return the sequences that are not finished."""
        pending = []
        for member in self.seqs_dict.values():
            if not member.is_finished():
                pending.append(member)
        return pending

    def get_finished_seqs(self) -> List[Sequence]:
        """Return the sequences that are finished."""
        done = []
        for member in self.seqs_dict.values():
            if member.is_finished():
                done.append(member)
        return done

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        """Count the sequences, optionally restricted to one status."""
        matching = self.get_seqs(status)
        return len(matching)

    def num_unfinished_seqs(self) -> int:
        """Count the sequences that are not finished."""
        pending = self.get_unfinished_seqs()
        return len(pending)

    def num_finished_seqs(self) -> int:
        """Count the sequences that are finished."""
        done = self.get_finished_seqs()
        return len(done)

    def find(self, seq_id: int) -> Sequence:
        """Return the sequence with `seq_id`; raise ValueError if absent."""
        if seq_id in self.seqs_dict:
            return self.seqs_dict[seq_id]
        raise ValueError(f"Sequence {seq_id} not found.")

    def add(self, seq: Sequence) -> None:
        """Register `seq`; raise ValueError if its ID is already present."""
        if seq.seq_id not in self.seqs_dict:
            self.seqs_dict[seq.seq_id] = seq
            return
        raise ValueError(f"Sequence {seq.seq_id} already exists.")

    def remove(self, seq_id: int) -> None:
        """Drop the sequence with `seq_id`; raise ValueError if absent."""
        if seq_id in self.seqs_dict:
            del self.seqs_dict[seq_id]
            return
        raise ValueError(f"Sequence {seq_id} not found.")

    def is_finished(self) -> bool:
        """Return True when every sequence is finished (also when empty)."""
        for member in self.get_seqs():
            if not member.is_finished():
                return False
        return True

    def __repr__(self) -> str:
        return (
            f"SequenceGroup(request_id={self.request_id}, "
            f"sampling_params={self.sampling_params}, "
            f"num_seqs={len(self.seqs_dict)})"
        )
class SequenceGroupMetadata:
    """Metadata for a sequence group. Used to create `InputMetadata`.

    A plain holder: each constructor argument is stored unchanged on the
    attribute of the same name.

    Args:
        request_id: The ID of the request.
        is_prompt: Whether the request is at prompt stage.
        seq_data: The sequence data. (Seq id -> sequence data)
        sampling_params: The sampling parameters used to generate the outputs.
        block_tables: The block tables. (Seq id -> list of physical block
            numbers)
    """

    def __init__(
        self,
        request_id: str,
        is_prompt: bool,
        seq_data: Dict[int, SequenceData],
        sampling_params: SamplingParams,
        block_tables: Dict[int, List[int]],
    ) -> None:
        # Request-level fields.
        self.request_id, self.is_prompt = request_id, is_prompt
        self.sampling_params = sampling_params
        # Per-sequence tables, both keyed by sequence ID.
        self.seq_data, self.block_tables = seq_data, block_tables
class SequenceOutput:
    """The model output associated with a sequence.

    Args:
        parent_seq_id: The ID of the parent sequence (for forking in beam
            search).
        output_token: The output token ID.
        logprobs: The logprobs of the output token.
            (Token id -> logP(x_i+1 | x_0, ..., x_i))
        hidden_states: Optional hidden states attached to this output.
        finished: Flag stored with the output; defaults to False.
    """

    def __init__(
        self,
        parent_seq_id: int,
        output_token: int,
        logprobs: Dict[int, float],
        hidden_states: Optional[torch.Tensor] = None,
        finished: bool = False,
    ) -> None:
        self.parent_seq_id = parent_seq_id
        self.output_token = output_token
        self.logprobs = logprobs
        self.finished = finished
        self.hidden_states = hidden_states

    def __repr__(self) -> str:
        # Only the shape of the hidden states is shown, not the tensor itself.
        hidden_shape = (
            self.hidden_states.shape if self.hidden_states is not None else None
        )
        return (
            f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
            f"output_token={self.output_token}, "
            f"logprobs={self.logprobs}, "
            f"finished={self.finished}, "
            f"hidden_states={hidden_shape})"
        )

    def __eq__(self, other: object) -> bool:
        """Compare parent ID, output token and logprobs.

        `finished` and `hidden_states` are not part of the comparison.
        """
        if not isinstance(other, SequenceOutput):
            # Let Python fall back to its default handling, so comparing with
            # an unrelated object yields False instead of raising.
            return NotImplemented
        return (
            self.parent_seq_id == other.parent_seq_id
            and self.output_token == other.output_token
            and self.logprobs == other.logprobs
        )
class SequenceGroupOutput:
    """The model output associated with a sequence group.

    Args:
        samples: The per-sequence outputs of the group.
        prompt_logprobs: The logprobs of the prompt, or None.
    """

    def __init__(
        self,
        samples: List[SequenceOutput],
        prompt_logprobs: Optional[PromptLogprobs],
    ) -> None:
        self.samples = samples
        self.prompt_logprobs = prompt_logprobs

    def __repr__(self) -> str:
        return (
            f"SequenceGroupOutput(samples={self.samples}, "
            f"prompt_logprobs={self.prompt_logprobs})"
        )

    def __eq__(self, other: object) -> bool:
        """Compare the samples and the prompt logprobs."""
        if not isinstance(other, SequenceGroupOutput):
            # Let Python fall back to its default handling, so comparing with
            # an unrelated object yields False instead of raising.
            return NotImplemented
        return (
            self.samples == other.samples
            and self.prompt_logprobs == other.prompt_logprobs
        )


# A list of per-group outputs (presumably one entry per sequence group
# processed in a step -- confirm against the sampler).
SamplerOutput = List[SequenceGroupOutput]