certifaier / vllm /model_executor /sampling_metadata.py
bsmit1659's picture
Adding vllm package
ca1ecab
raw
history blame
9.68 kB
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import torch
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SequenceData
from vllm.utils import in_wsl
# Tolerance used when comparing floating-point sampling parameters against
# their neutral values (temperature ~ 0, top_p ~ 1, min_p ~ 0, penalties ~ 0
# or ~ 1), so that tiny numerical noise does not enable a sampling feature.
_SAMPLING_EPS = 1e-5
class SamplingMetadata:
    """Metadata for input sequences. Used in sampler.

    Args:
        seq_groups: List of (seq_ids, sampling_params).
        seq_data: Seq_id -> SequenceData.
        prompt_lens: Lengths of prompts.
        selected_token_indices: Token indices selected for sampling.
        categorized_sample_indices: SamplingType -> token indices to sample.
        perform_sampling: Whether to perform sampling. This option is used to
            make sampling happen only in the driver worker, and to disable
            sampling in other worker processes.
    """

    def __init__(
        self,
        seq_groups: Optional[List[Tuple[List[int], SamplingParams]]],
        seq_data: Optional[Dict[int, SequenceData]],
        prompt_lens: Optional[List[int]],
        selected_token_indices: torch.Tensor,
        categorized_sample_indices: Optional[Dict[SamplingType, torch.Tensor]],
        perform_sampling: bool = True,
    ) -> None:
        self.seq_groups = seq_groups
        self.seq_data = seq_data
        self.prompt_lens = prompt_lens
        self.selected_token_indices = selected_token_indices
        self.categorized_sample_indices = categorized_sample_indices
        self.perform_sampling = perform_sampling

        # Number of prompt entries; 0 when no prompt lengths were supplied.
        self.num_prompts = len(prompt_lens) if prompt_lens is not None else 0

    def __repr__(self) -> str:
        """Return a debug representation listing every constructor field."""
        # NOTE: the closing parenthesis belongs only after the final field;
        # a stray ")" after categorized_sample_indices used to unbalance the
        # output.
        return (
            "SamplingMetadata("
            f"seq_groups={self.seq_groups}, "
            f"seq_data={self.seq_data}, "
            f"prompt_lens={self.prompt_lens}, "
            f"selected_token_indices={self.selected_token_indices}, "
            f"categorized_sample_indices={self.categorized_sample_indices}, "
            f"perform_sampling={self.perform_sampling})")
@dataclass
class SamplingTensors:
    """Tensors for sampling."""

    # Per-row sampling parameters, one entry per sampled/logprob row.
    temperatures: torch.Tensor
    top_ps: torch.Tensor
    top_ks: torch.Tensor
    min_ps: torch.Tensor
    presence_penalties: torch.Tensor
    frequency_penalties: torch.Tensor
    repetition_penalties: torch.Tensor
    # Token-id matrices, right-padded with `vocab_size` (see `from_lists`).
    prompt_tokens: torch.Tensor
    output_tokens: torch.Tensor

    @classmethod
    def from_sampling_metadata(
            cls, sampling_metadata: "SamplingMetadata", vocab_size: int,
            device: torch.device,
            dtype: torch.dtype) -> Tuple["SamplingTensors", bool, bool, bool]:
        """Build sampling tensors from the per-group sampling parameters.

        Args:
            sampling_metadata: Metadata whose `seq_groups`, `seq_data`,
                `prompt_lens` and `num_prompts` are read.
            vocab_size: Vocabulary size; caps `top_k` and is used as the
                padding value for the token matrices.
            device: Device the resulting tensors are moved to.
            dtype: Dtype used for the floating-point parameter tensors.

        Returns:
            A tuple `(sampling_tensors, do_penalties, do_top_p_top_k,
            do_min_p)`. Each flag is True when at least one sequence group
            uses a non-neutral value for the corresponding feature.
        """
        prompt_tokens: List[List[int]] = []
        output_tokens: List[List[int]] = []
        top_ks: List[int] = []
        temperatures: List[float] = []
        top_ps: List[float] = []
        min_ps: List[float] = []
        presence_penalties: List[float] = []
        frequency_penalties: List[float] = []
        repetition_penalties: List[float] = []
        do_penalties = False
        do_top_p_top_k = False
        do_min_p = False

        for i, seq_group in enumerate(sampling_metadata.seq_groups):
            seq_ids, sampling_params = seq_group
            temperature = sampling_params.temperature
            p = sampling_params.presence_penalty
            f = sampling_params.frequency_penalty
            r = sampling_params.repetition_penalty
            top_p = sampling_params.top_p
            min_p = sampling_params.min_p

            # k should not be greater than the vocab size.
            top_k = min(sampling_params.top_k, vocab_size)
            # -1 means "no top-k restriction", i.e. the whole vocabulary.
            top_k = vocab_size if top_k == -1 else top_k

            if temperature < _SAMPLING_EPS:
                # NOTE: Zero temperature means deterministic sampling
                # (i.e., greedy sampling or beam search).
                # Set the temperature to 1 to avoid division by zero.
                temperature = 1.0
            if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS
                                       or top_k != vocab_size):
                do_top_p_top_k = True
            if not do_min_p and min_p > _SAMPLING_EPS:
                do_min_p = True
            if not do_penalties and (abs(p) >= _SAMPLING_EPS
                                     or abs(f) >= _SAMPLING_EPS
                                     or abs(r - 1.0) >= _SAMPLING_EPS):
                do_penalties = True

            if (i < sampling_metadata.num_prompts
                    and sampling_params.prompt_logprobs is not None):
                # For tokens in the prompt that we only need to get their
                # logprobs. These rows get neutral penalties and empty token
                # lists.
                # NOTE(review): presumably the last prompt position is
                # covered by the per-sequence rows added below, hence
                # `prompt_len - 1` -- confirm against the sampler.
                prompt_len = sampling_metadata.prompt_lens[i]
                num_logprob_rows = prompt_len - 1
                temperatures += [temperature] * num_logprob_rows
                top_ps += [top_p] * num_logprob_rows
                top_ks += [top_k] * num_logprob_rows
                min_ps += [min_p] * num_logprob_rows
                presence_penalties += [0.0] * num_logprob_rows
                frequency_penalties += [0.0] * num_logprob_rows
                repetition_penalties += [1.0] * num_logprob_rows
                prompt_tokens.extend([] for _ in range(num_logprob_rows))
                output_tokens.extend([] for _ in range(num_logprob_rows))

            # One row per sequence in the group, all sharing the group's
            # sampling parameters.
            for seq_id in seq_ids:
                seq_data = sampling_metadata.seq_data[seq_id]
                prompt_tokens.append(seq_data.prompt_token_ids)
                output_tokens.append(seq_data.output_token_ids)
            num_seqs = len(seq_ids)
            temperatures += [temperature] * num_seqs
            top_ps += [top_p] * num_seqs
            top_ks += [top_k] * num_seqs
            min_ps += [min_p] * num_seqs
            presence_penalties += [p] * num_seqs
            frequency_penalties += [f] * num_seqs
            repetition_penalties += [r] * num_seqs

        # Use `cls` so that subclasses get an instance of their own type.
        sampling_tensors = cls.from_lists(
            temperatures, top_ps, top_ks, min_ps, presence_penalties,
            frequency_penalties, repetition_penalties, prompt_tokens,
            output_tokens, vocab_size, device, dtype)
        return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p)

    @classmethod
    def from_lists(cls, temperatures: List[float], top_ps: List[float],
                   top_ks: List[int], min_ps: List[float],
                   presence_penalties: List[float],
                   frequency_penalties: List[float],
                   repetition_penalties: List[float],
                   prompt_tokens: List[List[int]],
                   output_tokens: List[List[int]], vocab_size: int,
                   device: torch.device,
                   dtype: torch.dtype) -> "SamplingTensors":
        """Convert per-row Python lists into tensors on `device`.

        Args:
            temperatures, top_ps, min_ps, presence_penalties,
            frequency_penalties, repetition_penalties: Per-row float
                parameters, converted with `dtype`.
            top_ks: Per-row top-k values, converted to `torch.int`.
            prompt_tokens: Per-row prompt token ids (ragged).
            output_tokens: Per-row output token ids (ragged).
            vocab_size: Used as the padding value when right-padding the
                ragged token lists to a rectangular shape.
            device: Target device for all resulting tensors.
            dtype: Dtype of the floating-point parameter tensors.

        Returns:
            A new instance whose fields are tensors on `device`. Empty input
            lists are accepted and yield empty tensors.
        """
        # Note that the performance will be very bad without
        # pinned memory.
        pin_memory = not in_wsl()

        def _cpu_tensor(values, tensor_dtype: torch.dtype) -> torch.Tensor:
            # Stage the data on the CPU (pinned when possible) before the
            # transfer to the target device.
            return torch.tensor(
                values,
                device="cpu",
                dtype=tensor_dtype,
                pin_memory=pin_memory,
            )

        # `default=0` keeps empty batches from raising ValueError in max().
        prompt_max_len = max((len(tokens) for tokens in prompt_tokens),
                             default=0)
        prompt_padded_tokens = [
            tokens + [vocab_size] * (prompt_max_len - len(tokens))
            for tokens in prompt_tokens
        ]
        output_max_len = max((len(tokens) for tokens in output_tokens),
                             default=0)
        output_padded_tokens = [
            tokens + [vocab_size] * (output_max_len - len(tokens))
            for tokens in output_tokens
        ]

        temperatures_t = _cpu_tensor(temperatures, dtype)
        top_ps_t = _cpu_tensor(top_ps, dtype)
        min_ps_t = _cpu_tensor(min_ps, dtype)
        presence_penalties_t = _cpu_tensor(presence_penalties, dtype)
        frequency_penalties_t = _cpu_tensor(frequency_penalties, dtype)
        repetition_penalties_t = _cpu_tensor(repetition_penalties, dtype)
        top_ks_t = _cpu_tensor(top_ks, torch.int)
        prompt_tensor = _cpu_tensor(prompt_padded_tokens, torch.long)
        output_tensor = _cpu_tensor(output_padded_tokens, torch.long)

        # Because the memory is pinned, we can do non-blocking
        # transfer to device.
        return cls(
            temperatures=temperatures_t.to(device=device, non_blocking=True),
            top_ps=top_ps_t.to(device=device, non_blocking=True),
            top_ks=top_ks_t.to(device=device, non_blocking=True),
            min_ps=min_ps_t.to(device=device, non_blocking=True),
            presence_penalties=presence_penalties_t.to(device=device,
                                                       non_blocking=True),
            frequency_penalties=frequency_penalties_t.to(device=device,
                                                         non_blocking=True),
            repetition_penalties=repetition_penalties_t.to(device=device,
                                                           non_blocking=True),
            prompt_tokens=prompt_tensor.to(device=device, non_blocking=True),
            output_tokens=output_tensor.to(device=device, non_blocking=True),
        )