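"""Flash-attention causal language modeling for text_generation_server.

FlashCausalLMBatch stores requests as flat, unpadded tensors addressed through
cumulative sequence lengths (cu_seqlens); FlashCausalLM wraps a flash-attention
model and generates one token per request per forward pass.
"""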
import torch
import torch.distributed

from torch.nn import functional as F

from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel
from typing import Optional, Tuple, List, Type, Union, Dict

from text_generation_server.models import Model
from text_generation_server.models.types import (
    Batch,
    PrefillTokens,
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import (
    NextTokenChooser,
    StoppingCriteria,
    Sampling,
)

tracer = trace.get_tracer(__name__)


@dataclass
class FlashCausalLMBatch(Batch):
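    """A batch laid out for flash attention.

    Sequences are stored flat and unpadded; ``cu_seqlens`` holds the cumulative
    sequence lengths used to address each request inside the flattened tensors.
    """
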
    batch_id: int
    requests: List[generate_pb2.Request]
    # request id -> idx in list mapping
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: List[torch.Tensor]
    position_ids: List[torch.Tensor]

    # cumulative sequence lengths
    cu_seqlens: List[int]
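    # e.g. two prompts of lengths 3 and 5 flatten to 8 tokens with
    # cu_seqlens == [0, 3, 8]; sequence i spans cu_seqlens[i]:cu_seqlens[i + 1]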
    max_seqlen: int
    past_key_values: Optional[Union[torch.Tensor, List[torch.Tensor]]]

    # All tokens
    all_input_ids: List[List[int]]
    all_input_ids_tensor: List[torch.Tensor]

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    offsets: List[Optional[int]]
    token_offsets: List[Optional[int]]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Constant shared tensor, ref here just so that it's accessible in concatenate()
    past_pad: Optional[torch.Tensor]

    # Maximum number of tokens this batch will grow to
    max_tokens: int

    def to_pb(self) -> generate_pb2.Batch:
        return generate_pb2.Batch(
            id=self.batch_id,
            requests=self.requests,
            size=len(self),
            max_tokens=self.max_tokens,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        input_ids = []
        position_ids = []
        cu_seqlens = [0]
        max_seqlen = 0

        input_lengths = []
        offsets = []
        token_offsets = []
        all_input_ids = []
        all_input_ids_tensor = []
        requests_idx_mapping = {}

        next_token_choosers = []
        stopping_criterias = []

        # Cumulative length
        cumulative_length = 0

        max_tokens = 0

        # Parse batch
        for i, r in enumerate(pb.requests):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            tokenized_input = tokenizer(
                r.inputs, truncation=True, max_length=r.truncate
            )["input_ids"]

            input_length = len(tokenized_input)
            max_seqlen = max(max_seqlen, input_length)
            input_lengths.append(input_length)
            offsets.append(None)
            token_offsets.append(None)
            all_input_ids.append(tokenized_input)

            tokenized_input = torch.tensor(tokenized_input, device=device)
            input_ids.append(tokenized_input)

            # Position ids
            position_ids.append(
                torch.arange(0, input_length, dtype=torch.int32, device=device)
            )

            # Add cumulative lengths of all previous inputs
            cu_seqlens.append(cumulative_length + input_length)

            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)

            all_input_ids_tensor.append(
                F.pad(tokenized_input, (0, stopping_criteria.max_new_tokens))
            )

            # Update
            cumulative_length += input_length
            max_tokens += input_length + max_new_tokens

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            past_key_values=None,
            input_lengths=input_lengths,
            offsets=offsets,
            token_offsets=token_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            past_pad=None,
            max_tokens=max_tokens,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, requests: List[generate_pb2.Request]) -> "FlashCausalLMBatch":
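        """Keep only the given requests, rebuilding indices, lengths and the
        KV-cache entries for the surviving sequences."""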
        if len(requests) == 0:
            raise ValueError("Batch must have at least one request")
        # We assume that if len(requests) == len(self) then the requests are the same
        if len(requests) == len(self):
            return self

        single_request = len(requests) == 1

        # Cumulative length
        cumulative_length = 0

        # New values after filtering
        requests_idx_mapping = {}

        input_ids = []
        position_ids = []
        cu_seqlens = [0]
        max_seqlen = 0
        past_key_values = []

        all_input_ids = []
        all_input_ids_tensor = []

        input_lengths = []
        offsets = []
        token_offsets = []

        next_token_choosers = []
        stopping_criterias = []

        max_tokens = 0

        for i, r in enumerate(requests):
            idx = self.requests_idx_mapping[r.id]
            requests_idx_mapping[r.id] = i

            # Get length
            request_input_length = self.input_lengths[idx]

            input_ids.append(self.input_ids[idx])
            position_ids.append(self.position_ids[idx])
            cu_seqlens.append(cumulative_length + request_input_length)
            max_seqlen = max(max_seqlen, request_input_length)
            # past_key_values is interleaved as [seq_0, pad, seq_1, pad, ...],
            # so the entry for request idx lives at index 2 * idx
            past_key_values.append(self.past_key_values[2 * idx])
            if not single_request:
                # Add one padding
                past_key_values.append(self.past_pad)
            all_input_ids.append(self.all_input_ids[idx])
            all_input_ids_tensor.append(self.all_input_ids_tensor[idx])

            input_lengths.append(request_input_length)
            offsets.append(self.offsets[idx])
            token_offsets.append(self.token_offsets[idx])

            next_token_choosers.append(self.next_token_choosers[idx])

            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)

            cumulative_length += request_input_length
            max_tokens += request_input_length + (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )

        if single_request:
            # Preallocate tensor for bs = 1 case
            past_key_values = torch.nn.functional.pad(
                past_key_values[0],
                (
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    stopping_criterias[0].max_new_tokens
                    - stopping_criterias[0].current_tokens,
                ),
            )

        return FlashCausalLMBatch(
            batch_id=self.batch_id,
            past_pad=self.past_pad,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            offsets=offsets,
            token_offsets=token_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_tokens=max_tokens,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
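        """Merge several batches into one, offsetting ``cu_seqlens`` and the
        request id -> index mapping by the sizes of the preceding batches."""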
        # Batch attributes
        requests = []
        requests_idx_mapping = {}

        input_ids = []
        position_ids = []
        cu_seqlens = [0]
        max_seqlen = 0
        past_key_values = []

        all_input_ids = []
        all_input_ids_tensor = []

        input_lengths = []
        offsets = []
        token_offsets = []

        next_token_choosers = []
        stopping_criterias = []

        # Cumulative length
        cumulative_batch_size = 0
        cumulative_length = 0

        max_tokens = 0

        for i, batch in enumerate(batches):
            requests.extend(batch.requests)

            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                # We need to offset the mapping for each batch by the cumulative batch size
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + cumulative_batch_size

            input_ids.extend(batch.input_ids)
            position_ids.extend(batch.position_ids)
            # Add cumulative lengths of all previous inputs
            cu_seqlens.extend([l + cumulative_length for l in batch.cu_seqlens[1:]])
            max_seqlen = max(max_seqlen, batch.max_seqlen)

            if len(batch) != 1:
                past_key_values.extend(batch.past_key_values)
            else:
                # past was pre-allocated for this batch
                # We need to slice to remove the padding
                past_key_values.append(
                    batch.past_key_values[:, : batch.input_lengths[0]]
                )
                # Add one padding
                past_key_values.append(batch.past_pad)

            all_input_ids.extend(batch.all_input_ids)
            all_input_ids_tensor.extend(batch.all_input_ids_tensor)

            input_lengths.extend(batch.input_lengths)
            offsets.extend(batch.offsets)
            token_offsets.extend(batch.token_offsets)

            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            # Update
            cumulative_length += batch.cu_seqlens[-1]
            cumulative_batch_size += len(batch)
            max_tokens += batch.max_tokens

        return FlashCausalLMBatch(
            batch_id=batches[0].batch_id,
            past_pad=batches[0].past_pad,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            offsets=offsets,
            token_offsets=token_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_tokens=max_tokens,
        )

    def __len__(self):
        return len(self.requests)


class FlashCausalLM(Model):
    def __init__(
        self,
        model_cls: Type[PreTrainedModel],
        model_id: str,
        revision: Optional[str] = None,
        quantize: bool = False,
        decode_buffer: int = 3,
    ):
        self.past_pad = None
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        else:
            raise NotImplementedError("FlashCausalLM is only available on GPU")

        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left", truncation_side="left"
        )
        self.model = (
            model_cls.from_pretrained(
                model_id,
                revision=revision,
                torch_dtype=dtype,
                load_in_8bit=quantize,
            )
            .eval()
            .to(device)
        )

        super(FlashCausalLM, self).__init__(
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            decode_buffer=decode_buffer,
        )

    @property
    def batch_type(self) -> Type[FlashCausalLMBatch]:
        return FlashCausalLMBatch

    def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_s: int,
        past_key_values: Optional[torch.Tensor] = None,
        pre_allocate_past_size: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
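        """Single model forward over the flattened tokens; returns
        (logits, present key/value cache)."""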
        # Model Forward
        return self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            max_s=max_s,
            past_key_values=past_key_values,
            pre_allocate_past_size=pre_allocate_past_size,
        )

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
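        """Run one forward pass and emit one new token per request.

        The first call for a batch is the prefill pass (no past_key_values yet);
        subsequent calls decode a single token per request.
        """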
        # Shortcut when batch_size == 1
        if len(batch) == 1:
            input_ids = batch.input_ids[0].view(-1)
            # No need to slice as flash attention will take care of it with cu_seqlens
            past_key_values = batch.past_key_values
        else:
            # Concatenate tensors
            input_ids = torch.cat(batch.input_ids).view(-1)
            past_key_values = (
                torch.cat(batch.past_key_values, dim=1)
                if batch.past_key_values is not None
                else None
            )

        # if prefill and bs == 1
        if past_key_values is None and len(batch) == 1:
            # Ask to pre-allocate kv to its max size
            # == number of tokens + max_new_tokens
            pre_allocate_past_size = (
                batch.input_lengths[0] + batch.stopping_criterias[0].max_new_tokens
            )
        else:
            pre_allocate_past_size = None

        # Concatenate when prefill, torch.tensor when decode
        position_ids = (
            torch.tensor(batch.position_ids, device=self.device)
            if batch.past_key_values is not None
            else torch.cat(batch.position_ids)
        )
        cu_seqlens = torch.tensor(
            batch.cu_seqlens, device=self.device, dtype=torch.int32
        )

        out, present = self.forward(
            input_ids,
            position_ids,
            cu_seqlens,
            batch.max_seqlen,
            past_key_values,
            pre_allocate_past_size,
        )

        # Initialize past_key_values in prefill
        if batch.past_key_values is None:
            # Initialize past padding tensor
            if self.past_pad is None:
                self.past_pad = present.new_zeros(
                    present.shape[0], 1, *present.shape[2:]
                )
            # Set in batch in case it needs to be used later in concatenate()
            batch.past_pad = self.past_pad
            if len(batch) == 1:
                # present is already pre-padded
                batch.past_key_values = present
            else:
                # Add padding after each sequence
                # This will have the correct shape after the final past_key_values
                # concatenation before the model forward
                batch.past_key_values = [None, self.past_pad] * len(batch)
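                # The None placeholders are filled per sequence below (index 2 * i),
                # yielding the interleaved layout [seq_0, pad, seq_1, pad, ...]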

        # Cumulative length
        cumulative_length = 0

        # Results
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.offsets,
            batch.token_offsets,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.all_input_ids_tensor,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            offset,
            token_offset,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
            all_input_ids_tensor,
        ) in enumerate(iterator):
            # Indexing metadata
            start_index = cumulative_length
            end_index = cumulative_length + input_length

            prefill = stopping_criteria.current_tokens == 0
            if prefill:
                # Prefill mode
                # out is of shape [cumulative_sequence_lengths, vocab_size]
                logits = out[start_index:end_index]
            else:
                # Decode mode
                # out is of shape [batch_size, vocab_size]
                logits = out[i].unsqueeze(0)

            # Select next token
            next_token_id, logprobs = next_token_chooser(
                all_input_ids_tensor[None, :input_length], logits
            )
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_id_item = next_token_id_squeezed.item()

            # Append next token to all tokens
            all_input_ids.append(next_token_id_item)
            all_input_ids_tensor[input_length] = next_token_id_item

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id_item]
            next_token_text, offset, token_offset = self.decode_token(
                all_input_ids,
                offset,
                token_offset,
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id_item,
                next_token_text,
            )

            if stop:
                # Decode generated tokens
                output_text = self.decode(
                    all_input_ids[-stopping_criteria.current_tokens :]
                )
                # Get seed
                if isinstance(next_token_chooser.choice, Sampling):
                    seed = next_token_chooser.choice.seed
                else:
                    seed = None

                generated_text = GeneratedText(
                    output_text, stopping_criteria.current_tokens, reason, seed
                )
            else:
                stopped = False
                generated_text = None

            # Prefill
            if prefill:
                # Remove generated token to only have prefill and add nan for first prompt token
                prefill_logprobs = [float("nan")] + logprobs.gather(
                    1, all_input_ids_tensor[1:input_length].unsqueeze(1)
                ).squeeze(1)[:-1].tolist()
                prefill_token_ids = all_input_ids[:-1]
                prefill_texts = self.tokenizer.batch_decode(
                    prefill_token_ids,
                    clean_up_tokenization_spaces=False,
                    skip_special_tokens=False,
                )
                prefill_tokens = PrefillTokens(
                    prefill_token_ids, prefill_logprobs, prefill_texts
                )
            else:
                prefill_tokens = None

            generation = Generation(
                request.id,
                prefill_tokens,
                next_token_id_item,
                next_token_logprob,
                next_token_text,
                next_token_id_item in self.all_special_ids,
                generated_text,
            )

            generations.append(generation)
            cumulative_length += input_length
            new_input_length = input_length + 1

            # Update values
            batch.input_ids[i] = next_token_id
            batch.position_ids[i] = input_length
            batch.input_lengths[i] = new_input_length
            batch.offsets[i] = offset
            batch.token_offsets[i] = token_offset
            batch.all_input_ids[i] = all_input_ids
            batch.all_input_ids_tensor[i] = all_input_ids_tensor
            batch.max_seqlen = max(batch.max_seqlen, new_input_length)
            if len(batch) != 1:
                # Add each sequence before its padding
                batch.past_key_values[i * 2] = present[:, start_index:end_index]
            # Cumulative sum
            batch.cu_seqlens[i + 1] = batch.cu_seqlens[i] + new_input_length

        # No need to return a batch if we know that all requests stopped
        return generations, batch if not stopped else None
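

# A minimal driver sketch (hypothetical; in the real server the gRPC router
# drives these calls). `MyFlashModel` and `pb_batch` are placeholders; only the
# FlashCausalLM / FlashCausalLMBatch API defined above is assumed:
#
#     model = FlashCausalLM(MyFlashModel, "my-org/my-model")
#     batch = FlashCausalLMBatch.from_pb(pb_batch, model.tokenizer, model.device)
#     while batch is not None:
#         # First iteration is the prefill pass; later ones decode one token each
#         generations, batch = model.generate_token(batch)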