from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch

from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SequenceData
from vllm.utils import in_wsl

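# _SAMPLING_EPS below is a float tolerance: parameters within it of their
# no-op values (top_p 1.0, min_p 0.0, penalties 0.0 / repetition 1.0) are
# treated as disabled, and temperatures below it are treated as 0 (greedy).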
_SAMPLING_EPS = 1e-5


class SamplingMetadata:
    """Metadata for input sequences. Used in sampler.

    Args:
        seq_groups: List of (seq_ids, sampling_params).
        seq_data: Seq_id -> SequenceData.
        prompt_lens: Lengths of prompts.
        selected_token_indices: Token indices selected for sampling.
        categorized_sample_indices: SamplingType -> token indices to sample.
        perform_sampling: Whether to perform sampling. This option is used
            to make sampling happen only in the driver worker and to
            disable it in the other worker processes.
    """
    def __init__(
        self,
        seq_groups: Optional[List[Tuple[List[int], SamplingParams]]],
        seq_data: Optional[Dict[int, SequenceData]],
        prompt_lens: Optional[List[int]],
        selected_token_indices: torch.Tensor,
        categorized_sample_indices: Optional[Dict[SamplingType,
                                                  torch.Tensor]],
        perform_sampling: bool = True,
    ) -> None:
        self.seq_groups = seq_groups
        self.seq_data = seq_data
        self.prompt_lens = prompt_lens
        self.selected_token_indices = selected_token_indices
        self.categorized_sample_indices = categorized_sample_indices
        self.perform_sampling = perform_sampling

        self.num_prompts = len(prompt_lens) if prompt_lens is not None else 0

    def __repr__(self) -> str:
        return (
            "SamplingMetadata("
            f"seq_groups={self.seq_groups}, "
            f"seq_data={self.seq_data}, "
            f"prompt_lens={self.prompt_lens}, "
            f"selected_token_indices={self.selected_token_indices}, "
            f"categorized_sample_indices={self.categorized_sample_indices}, "
            f"perform_sampling={self.perform_sampling})")


@dataclass
class SamplingTensors:
    """Tensors for sampling."""

    temperatures: torch.Tensor
    top_ps: torch.Tensor
    top_ks: torch.Tensor
    min_ps: torch.Tensor
    presence_penalties: torch.Tensor
    frequency_penalties: torch.Tensor
    repetition_penalties: torch.Tensor
    prompt_tokens: torch.Tensor
    output_tokens: torch.Tensor

    @classmethod
    def from_sampling_metadata(
            cls, sampling_metadata: "SamplingMetadata", vocab_size: int,
            device: torch.device,
            dtype: torch.dtype) -> Tuple["SamplingTensors", bool, bool, bool]:
        prompt_tokens: List[List[int]] = []
        output_tokens: List[List[int]] = []
        temperatures: List[float] = []
        top_ps: List[float] = []
        top_ks: List[int] = []
        min_ps: List[float] = []
        presence_penalties: List[float] = []
        frequency_penalties: List[float] = []
        repetition_penalties: List[float] = []
        do_penalties = False
        do_top_p_top_k = False
        do_min_p = False
        for i, seq_group in enumerate(sampling_metadata.seq_groups):
            seq_ids, sampling_params = seq_group
            temperature = sampling_params.temperature
            p = sampling_params.presence_penalty
            f = sampling_params.frequency_penalty
            r = sampling_params.repetition_penalty
            top_p = sampling_params.top_p
            min_p = sampling_params.min_p
            # k should not be greater than the vocab size.
            top_k = min(sampling_params.top_k, vocab_size)
            top_k = vocab_size if top_k == -1 else top_k
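            # In SamplingParams, top_k == -1 means "disabled"; mapping it to
            # vocab_size makes top-k a no-op for this sequence group.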
            if temperature < _SAMPLING_EPS:
                # NOTE: Zero temperature means deterministic sampling
                # (i.e., greedy sampling or beam search). Set the
                # temperature to 1 to avoid division by zero.
                temperature = 1.0
            if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS
                                       or top_k != vocab_size):
                do_top_p_top_k = True
            if not do_min_p and min_p > _SAMPLING_EPS:
                do_min_p = True
            if not do_penalties and (abs(p) >= _SAMPLING_EPS
                                     or abs(f) >= _SAMPLING_EPS
                                     or abs(r - 1.0) >= _SAMPLING_EPS):
                do_penalties = True
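            # Presence and frequency penalties are additive (no-op at 0.0),
            # while the repetition penalty is multiplicative (no-op at 1.0),
            # hence the different neutral values in the check above.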
            if (i < sampling_metadata.num_prompts
                    and sampling_params.prompt_logprobs is not None):
                # For prompt tokens we only need the logprobs, so append
                # one entry per prompt position with neutral penalties and
                # empty token lists.
                prompt_len = sampling_metadata.prompt_lens[i]
                temperatures += [temperature] * (prompt_len - 1)
                top_ps += [top_p] * (prompt_len - 1)
                top_ks += [top_k] * (prompt_len - 1)
                min_ps += [min_p] * (prompt_len - 1)
                presence_penalties += [0] * (prompt_len - 1)
                frequency_penalties += [0] * (prompt_len - 1)
                repetition_penalties += [1] * (prompt_len - 1)
                prompt_tokens.extend([] for _ in range(prompt_len - 1))
                output_tokens.extend([] for _ in range(prompt_len - 1))
            for seq_id in seq_ids:
                seq_data = sampling_metadata.seq_data[seq_id]
                prompt_tokens.append(seq_data.prompt_token_ids)
                output_tokens.append(seq_data.output_token_ids)
            temperatures += [temperature] * len(seq_ids)
            top_ps += [top_p] * len(seq_ids)
            top_ks += [top_k] * len(seq_ids)
            min_ps += [min_p] * len(seq_ids)
            presence_penalties += [p] * len(seq_ids)
            frequency_penalties += [f] * len(seq_ids)
            repetition_penalties += [r] * len(seq_ids)

        sampling_tensors = SamplingTensors.from_lists(
            temperatures, top_ps, top_ks, min_ps, presence_penalties,
            frequency_penalties, repetition_penalties, prompt_tokens,
            output_tokens, vocab_size, device, dtype)
        return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p)
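
    # The three returned booleans let the sampler skip the penalty,
    # top-p/top-k, and min-p kernels entirely when no sequence group
    # needs them.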

    @classmethod
    def from_lists(cls, temperatures: List[float], top_ps: List[float],
                   top_ks: List[int], min_ps: List[float],
                   presence_penalties: List[float],
                   frequency_penalties: List[float],
                   repetition_penalties: List[float],
                   prompt_tokens: List[List[int]],
                   output_tokens: List[List[int]], vocab_size: int,
                   device: torch.device,
                   dtype: torch.dtype) -> "SamplingTensors":
        # Note that the performance will be very bad without pinned memory.
        pin_memory = not in_wsl()
        prompt_max_len = max(len(tokens) for tokens in prompt_tokens)
        prompt_padded_tokens = [
            tokens + [vocab_size] * (prompt_max_len - len(tokens))
            for tokens in prompt_tokens
        ]
        output_max_len = max(len(tokens) for tokens in output_tokens)
        output_padded_tokens = [
            tokens + [vocab_size] * (output_max_len - len(tokens))
            for tokens in output_tokens
        ]
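        # vocab_size is one past the largest valid token id, so padding
        # entries remain distinguishable from real tokens downstream (e.g.
        # when counting token occurrences for the penalties).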

        temperatures_t = torch.tensor(
            temperatures,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        top_ps_t = torch.tensor(
            top_ps,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        min_ps_t = torch.tensor(
            min_ps,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        presence_penalties_t = torch.tensor(
            presence_penalties,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        frequency_penalties_t = torch.tensor(
            frequency_penalties,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        repetition_penalties_t = torch.tensor(
            repetition_penalties,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        top_ks_t = torch.tensor(
            top_ks,
            device="cpu",
            dtype=torch.int,
            pin_memory=pin_memory,
        )
        prompt_tensor = torch.tensor(
            prompt_padded_tokens,
            device="cpu",
            dtype=torch.long,
            pin_memory=pin_memory,
        )
        output_tensor = torch.tensor(
            output_padded_tokens,
            device="cpu",
            dtype=torch.long,
            pin_memory=pin_memory,
        )
        # Because the memory is pinned, we can do non-blocking
        # transfers to the device.
        return cls(
            temperatures=temperatures_t.to(device=device, non_blocking=True),
            top_ps=top_ps_t.to(device=device, non_blocking=True),
            top_ks=top_ks_t.to(device=device, non_blocking=True),
            min_ps=min_ps_t.to(device=device, non_blocking=True),
            presence_penalties=presence_penalties_t.to(device=device,
                                                       non_blocking=True),
            frequency_penalties=frequency_penalties_t.to(device=device,
                                                         non_blocking=True),
            repetition_penalties=repetition_penalties_t.to(device=device,
                                                           non_blocking=True),
            prompt_tokens=prompt_tensor.to(device=device, non_blocking=True),
            output_tokens=output_tensor.to(device=device, non_blocking=True),
        )
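

if __name__ == "__main__":
    # A minimal smoke test of SamplingTensors.from_lists. This demo block is
    # not part of the original module and the toy values are illustrative.
    # Note: from_lists pins host memory, which requires a CUDA-capable build
    # of PyTorch; on CPU-only builds the pin_memory calls above would fail.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tensors = SamplingTensors.from_lists(
        temperatures=[1.0, 0.7],
        top_ps=[1.0, 0.9],
        top_ks=[32000, 40],
        min_ps=[0.0, 0.0],
        presence_penalties=[0.0, 0.5],
        frequency_penalties=[0.0, 0.0],
        repetition_penalties=[1.0, 1.1],
        prompt_tokens=[[1, 2, 3], [4, 5]],
        output_tokens=[[6], []],
        vocab_size=32000,
        device=device,
        dtype=torch.float32,
    )
    # Shorter sequences are right-padded with vocab_size (an out-of-range id).
    print(tensors.prompt_tokens)  # tensor([[1, 2, 3], [4, 5, 32000]])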