"""Sequence and its related classes."""
import copy
import enum
from typing import Dict, List, Optional, Union

import torch

from vllm.block import LogicalTokenBlock

from .sampling_params import SamplingParams

PromptLogprobs = List[Optional[Dict[int, float]]]
SampleLogprobs = List[Dict[int, float]]


class SequenceStatus(enum.Enum):
    """Status of a sequence."""

    WAITING = enum.auto()
    RUNNING = enum.auto()
    SWAPPED = enum.auto()
    FINISHED_STOPPED = enum.auto()
    FINISHED_LENGTH_CAPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()
    FINISHED_IGNORED = enum.auto()

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        return status in [
            SequenceStatus.FINISHED_STOPPED,
            SequenceStatus.FINISHED_LENGTH_CAPPED,
            SequenceStatus.FINISHED_ABORTED,
            SequenceStatus.FINISHED_IGNORED,
        ]

    @staticmethod
    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
        if status == SequenceStatus.FINISHED_STOPPED:
            finish_reason = "stop"
        elif status == SequenceStatus.FINISHED_LENGTH_CAPPED:
            finish_reason = "length"
        elif status == SequenceStatus.FINISHED_ABORTED:
            finish_reason = "abort"
        elif status == SequenceStatus.FINISHED_IGNORED:
            # Ignored sequences (e.g. prompts longer than the model's context
            # window) are also reported as "length".
            finish_reason = "length"
        else:
            finish_reason = None
        return finish_reason
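

# Minimal usage sketch for the helpers above; the status value is arbitrary.
#
#     status = SequenceStatus.FINISHED_STOPPED
#     SequenceStatus.is_finished(status)          # -> True
#     SequenceStatus.get_finished_reason(status)  # -> "stop"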


class SequenceData:
    """Data associated with a sequence.

    Args:
        prompt_token_ids: The token IDs of the prompt.

    Attributes:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output.
        cumulative_logprob: The cumulative log probability of the output,
            tracked per logprob entry once the first token is appended.
        hidden_states: Hidden states accumulated for the generated tokens,
            if the caller provides them.
        finished: Whether the sequence has finished generating.
    """

    def __init__(
        self,
        prompt_token_ids: List[int],
    ) -> None:
        self.prompt_token_ids = prompt_token_ids
        self.output_token_ids: List[int] = []
        self.cumulative_logprob = 0.0
        self.hidden_states: Optional[torch.Tensor] = None
        self.finished = False

    def append_token_id(self, token_id: int, logprob: List[float]) -> None:
        # On the first append, promote the scalar 0.0 initializer to a list
        # with one slot per logprob entry.
        if isinstance(self.cumulative_logprob, float):
            self.cumulative_logprob = [
                0.0,
            ] * len(logprob)
        self.output_token_ids.append(token_id)
        for i in range(len(self.cumulative_logprob)):
            self.cumulative_logprob[i] += logprob[i]

    def append_hidden_states(self, hidden_states: torch.Tensor) -> None:
        if self.hidden_states is None:
            self.hidden_states = hidden_states
        else:
            self.hidden_states = torch.cat([self.hidden_states, hidden_states], dim=0)

    def get_len(self) -> int:
        return len(self.output_token_ids) + len(self.prompt_token_ids)

    def get_prompt_len(self) -> int:
        return len(self.prompt_token_ids)

    def get_output_len(self) -> int:
        return len(self.output_token_ids)

    def get_token_ids(self) -> List[int]:
        return self.prompt_token_ids + self.output_token_ids

    def get_last_token_id(self) -> int:
        if not self.output_token_ids:
            return self.prompt_token_ids[-1]
        return self.output_token_ids[-1]

    def __repr__(self) -> str:
        return (
            f"SequenceData("
            f"prompt_token_ids={self.prompt_token_ids}, "
            f"output_token_ids={self.output_token_ids}, "
            f"cumulative_logprob={self.cumulative_logprob}, "
            f"hidden_states={self.hidden_states.shape if self.hidden_states is not None else None}, "
            f"finished={self.finished})"
        )
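

# Minimal usage sketch for SequenceData. The token IDs and logprob values are
# made up; note that in this module `append_token_id` takes a list of logprobs
# for the appended token.
#
#     data = SequenceData(prompt_token_ids=[1, 2, 3])
#     data.append_token_id(7, [-0.5])
#     data.get_len()             # -> 4
#     data.get_last_token_id()   # -> 7
#     data.cumulative_logprob    # -> [-0.5]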


class Sequence:
    """Stores the data, status, and block information of a sequence.

    Args:
        seq_id: The ID of the sequence.
        prompt: The prompt of the sequence.
        prompt_token_ids: The token IDs of the prompt.
        block_size: The block size of the sequence. Should be the same as the
            block size used by the block manager and cache engine.
    """

    def __init__(
        self,
        seq_id: int,
        prompt: str,
        prompt_token_ids: List[int],
        block_size: int,
    ) -> None:
        self.seq_id = seq_id
        self.prompt = prompt
        self.block_size = block_size

        self.data = SequenceData(prompt_token_ids)
        self.output_logprobs: SampleLogprobs = []
        self.output_text = ""

        self.logical_token_blocks: List[LogicalTokenBlock] = []
        # Initialize the logical token blocks with the prompt token ids.
        self._append_tokens_to_blocks(prompt_token_ids)
        self.status = SequenceStatus.WAITING

        # Used for incremental detokenization.
        self.prefix_offset = 0
        self.read_offset = 0
        # Input + output tokens.
        self.tokens: Optional[List[str]] = None

    def _append_logical_block(self) -> None:
        block = LogicalTokenBlock(
            block_number=len(self.logical_token_blocks),
            block_size=self.block_size,
        )
        self.logical_token_blocks.append(block)

    def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
        cursor = 0
        while cursor < len(token_ids):
            if not self.logical_token_blocks:
                self._append_logical_block()

            last_block = self.logical_token_blocks[-1]
            if last_block.is_full():
                self._append_logical_block()
                last_block = self.logical_token_blocks[-1]

            num_empty_slots = last_block.get_num_empty_slots()
            last_block.append_tokens(token_ids[cursor : cursor + num_empty_slots])
            cursor += num_empty_slots
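
    # For example, with block_size=2 the (hypothetical) token IDs [1, 2, 3, 4, 5]
    # fill two full logical blocks and start a third:
    #   block 0: [1, 2]   block 1: [3, 4]   block 2: [5, _]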

    def append_token_id(
        self,
        token_id: int,
        logprobs: Dict[int, float],
        hidden_states: Optional[torch.Tensor] = None,
        finished: bool = False,
    ) -> None:
        assert token_id in logprobs
        self._append_tokens_to_blocks([token_id])
        self.output_logprobs.append(logprobs)
        self.data.append_token_id(token_id, logprobs[token_id])
        if hidden_states is not None:
            self.data.append_hidden_states(hidden_states)
        self.data.finished = finished

    def get_len(self) -> int:
        return self.data.get_len()

    def get_prompt_len(self) -> int:
        return self.data.get_prompt_len()

    def get_output_len(self) -> int:
        return self.data.get_output_len()

    def get_token_ids(self) -> List[int]:
        return self.data.get_token_ids()

    def get_last_token_id(self) -> int:
        return self.data.get_last_token_id()

    def get_output_token_ids(self) -> List[int]:
        return self.data.output_token_ids

    def get_cumulative_logprob(self) -> float:
        return self.data.cumulative_logprob

    def get_beam_search_score(
        self,
        length_penalty: float = 0.0,
        seq_len: Optional[int] = None,
        eos_token_id: Optional[int] = None,
    ) -> float:
        """Calculate the beam search score with length penalty.

        Adapted from
        https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
        """
        if seq_len is None:
            seq_len = self.get_len()
            # Note: the HF implementation does not count the EOS token
            # towards the length.
            if eos_token_id is not None and self.get_last_token_id() == eos_token_id:
                seq_len -= 1
        return self.get_cumulative_logprob() / (seq_len**length_penalty)
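
    # Worked example with made-up numbers: if cumulative_logprob is -6.0,
    # seq_len is 4 and length_penalty is 1.0, the score is -6.0 / 4**1.0 = -1.5;
    # with the default length_penalty of 0.0 the score stays -6.0.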

    def is_finished(self) -> bool:
        return SequenceStatus.is_finished(self.status)

    def fork(self, new_seq_id: int) -> "Sequence":
        new_seq = copy.deepcopy(self)
        new_seq.seq_id = new_seq_id
        return new_seq

    def __repr__(self) -> str:
        return (
            f"Sequence(seq_id={self.seq_id}, "
            f"status={self.status.name}, "
            f"num_blocks={len(self.logical_token_blocks)})"
        )
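

# Sketch of constructing a Sequence and appending one generated token; all
# concrete values (IDs, block size, logprobs) are hypothetical.
#
#     seq = Sequence(seq_id=0, prompt="Hello", prompt_token_ids=[101, 102],
#                    block_size=16)
#     seq.append_token_id(token_id=103, logprobs={103: [-0.1]})
#     seq.get_len()      # -> 3
#     seq.is_finished()  # -> False (status is still WAITING)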


class SequenceGroup:
    """A group of sequences that are generated from the same prompt.

    Args:
        request_id: The ID of the request.
        seqs: The list of sequences.
        sampling_params: The sampling parameters used to generate the outputs.
        arrival_time: The arrival time of the request.
    """

    def __init__(
        self,
        request_id: str,
        seqs: List[Sequence],
        sampling_params: SamplingParams,
        arrival_time: float,
    ) -> None:
        self.request_id = request_id
        self.seqs_dict = {seq.seq_id: seq for seq in seqs}
        self.sampling_params = sampling_params
        self.arrival_time = arrival_time
        self.prompt_logprobs: Optional[PromptLogprobs] = None

    @property
    def prompt(self) -> str:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        return next(iter(self.seqs_dict.values())).prompt

    @property
    def prompt_token_ids(self) -> List[int]:
        # All sequences in the group should have the same prompt.
        # We use the prompt token IDs of an arbitrary sequence.
        return next(iter(self.seqs_dict.values())).data.prompt_token_ids

    def get_max_num_running_seqs(self) -> int:
        """The maximum number of sequences running in parallel in the remaining
        lifetime of the request."""
        if self.sampling_params.use_beam_search:
            # For beam search, there will always be at most `best_of` beam
            # candidates running in the future.
            return self.sampling_params.best_of
        else:
            if self.sampling_params.best_of > self.num_seqs():
                # At the prompt stage, the sequence group is not yet filled up
                # and only has one sequence running. In the generation stage,
                # however, there will be `best_of` sequences running.
                return self.sampling_params.best_of
            # At sampling stages, return the number of actual sequences
            # that are not finished yet.
            return self.num_unfinished_seqs()

    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> List[Sequence]:
        if status is None:
            return list(self.seqs_dict.values())
        else:
            return [seq for seq in self.seqs_dict.values() if seq.status == status]

    def get_unfinished_seqs(self) -> List[Sequence]:
        return [seq for seq in self.seqs_dict.values() if not seq.is_finished()]

    def get_finished_seqs(self) -> List[Sequence]:
        return [seq for seq in self.seqs_dict.values() if seq.is_finished()]

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        return len(self.get_seqs(status))

    def num_unfinished_seqs(self) -> int:
        return len(self.get_unfinished_seqs())

    def num_finished_seqs(self) -> int:
        return len(self.get_finished_seqs())

    def find(self, seq_id: int) -> Sequence:
        if seq_id not in self.seqs_dict:
            raise ValueError(f"Sequence {seq_id} not found.")
        return self.seqs_dict[seq_id]

    def add(self, seq: Sequence) -> None:
        if seq.seq_id in self.seqs_dict:
            raise ValueError(f"Sequence {seq.seq_id} already exists.")
        self.seqs_dict[seq.seq_id] = seq

    def remove(self, seq_id: int) -> None:
        if seq_id not in self.seqs_dict:
            raise ValueError(f"Sequence {seq_id} not found.")
        del self.seqs_dict[seq_id]

    def is_finished(self) -> bool:
        return all(seq.is_finished() for seq in self.get_seqs())

    def __repr__(self) -> str:
        return (
            f"SequenceGroup(request_id={self.request_id}, "
            f"sampling_params={self.sampling_params}, "
            f"num_seqs={len(self.seqs_dict)})"
        )
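

# Sketch of a SequenceGroup wrapping the hypothetical `seq` from the sketch
# above; the `SamplingParams()` defaults and the arrival time are placeholders.
#
#     group = SequenceGroup(request_id="req-0", seqs=[seq],
#                           sampling_params=SamplingParams(),
#                           arrival_time=time.time())
#     group.num_seqs()                        # -> 1
#     group.get_seqs(SequenceStatus.WAITING)  # -> [seq]
#     group.is_finished()                     # -> False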


class SequenceGroupMetadata:
    """Metadata for a sequence group. Used to create `InputMetadata`.

    Args:
        request_id: The ID of the request.
        is_prompt: Whether the request is at prompt stage.
        seq_data: The sequence data. (Seq id -> sequence data)
        sampling_params: The sampling parameters used to generate the outputs.
        block_tables: The block tables. (Seq id -> list of physical block
            numbers)
    """

    def __init__(
        self,
        request_id: str,
        is_prompt: bool,
        seq_data: Dict[int, SequenceData],
        sampling_params: SamplingParams,
        block_tables: Dict[int, List[int]],
    ) -> None:
        self.request_id = request_id
        self.is_prompt = is_prompt
        self.seq_data = seq_data
        self.sampling_params = sampling_params
        self.block_tables = block_tables
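

# Sketch of the metadata for a prompt-stage group; the physical block numbers
# in `block_tables` and the other values are made up for illustration.
#
#     meta = SequenceGroupMetadata(
#         request_id="req-0",
#         is_prompt=True,
#         seq_data={0: seq.data},
#         sampling_params=SamplingParams(),
#         block_tables={0: [3, 7]},
#     )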


class SequenceOutput:
    """The model output associated with a sequence.

    Args:
        parent_seq_id: The ID of the parent sequence (for forking in beam
            search).
        output_token: The output token ID.
        logprobs: The logprobs of the output token.
            (Token id -> logP(x_i+1 | x_0, ..., x_i))
        hidden_states: Optional hidden states associated with the output token.
        finished: Whether the sequence has finished generating.
    """

    def __init__(
        self,
        parent_seq_id: int,
        output_token: int,
        logprobs: Dict[int, float],
        hidden_states: Optional[torch.Tensor] = None,
        finished: bool = False,
    ) -> None:
        self.parent_seq_id = parent_seq_id
        self.output_token = output_token
        self.logprobs = logprobs
        self.finished = finished
        self.hidden_states = hidden_states

    def __repr__(self) -> str:
        return (
            f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
            f"output_token={self.output_token}, "
            f"logprobs={self.logprobs}, "
            f"finished={self.finished}, "
            f"hidden_states="
            f"{self.hidden_states.shape if self.hidden_states is not None else None})"
        )

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, SequenceOutput):
            raise NotImplementedError()
        return (
            self.parent_seq_id == other.parent_seq_id
            and self.output_token == other.output_token
            and self.logprobs == other.logprobs
        )


class SequenceGroupOutput:
    """The model output associated with a sequence group."""

    def __init__(
        self,
        samples: List[SequenceOutput],
        prompt_logprobs: Optional[PromptLogprobs],
    ) -> None:
        self.samples = samples
        self.prompt_logprobs = prompt_logprobs

    def __repr__(self) -> str:
        return (
            f"SequenceGroupOutput(samples={self.samples}, "
            f"prompt_logprobs={self.prompt_logprobs})"
        )

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, SequenceGroupOutput):
            raise NotImplementedError()
        return (
            self.samples == other.samples
            and self.prompt_logprobs == other.prompt_logprobs
        )


# The output of the sampler: one SequenceGroupOutput per scheduled sequence
# group.
SamplerOutput = List[SequenceGroupOutput]
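

# Sketch of assembling a SamplerOutput for one sequence group with a single
# sampled token; the IDs and logprob are made up.
#
#     sample = SequenceOutput(parent_seq_id=0, output_token=103,
#                             logprobs={103: -0.1})
#     sampler_output: SamplerOutput = [
#         SequenceGroupOutput(samples=[sample], prompt_logprobs=None)
#     ]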