"""Request/response dataclasses and a base ``Decoder`` that drives LLM
generation with optional streaming, token accumulation, and speculative
decoding."""

from collections.abc import Generator
from dataclasses import dataclass, field
from typing import Optional

import numpy as np


class RequestValidationError(Exception):
    """Raised when a Request fails validation."""


def _validate_that(condition: bool, msg: str) -> None:
    if not condition:
        raise RequestValidationError(msg)


def _validate_non_empty(data, msg: str) -> None:
    _validate_that(data is not None and data.size > 0, msg)


def _validate_single_gt_0(data, msg: str) -> None:
    _validate_non_empty(data, msg)
    _validate_that(data.flatten()[0] > 0, msg)


def _single_value(data: Optional[np.ndarray]):
    """Return the first element of ``data``, or None if no array was given."""
    if data is None:
        return None
    return data.flatten()[0]
|
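# Note: scalar request fields in this module arrive as 2D numpy arrays
# (e.g. max_tokens == np.array([[64]])), which is why the helpers above read
# data.flatten()[0]; for instance _single_value(np.array([[64]])) returns 64.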
|
|
@dataclass
class Request:
    text_input: np.ndarray = field(default_factory=lambda: np.array([]))
    decoder_text_input: Optional[np.ndarray] = None
    max_tokens: np.ndarray = field(default_factory=lambda: np.array([]))
    bad_words: Optional[np.ndarray] = None
    stop_words: Optional[np.ndarray] = None
    end_id: Optional[np.ndarray] = None
    pad_id: Optional[np.ndarray] = None
    top_k: Optional[np.ndarray] = None
    top_p: Optional[np.ndarray] = None
    temperature: Optional[np.ndarray] = None
    length_penalty: Optional[np.ndarray] = None
    repetition_penalty: Optional[np.ndarray] = None
    min_length: Optional[np.ndarray] = None
    return_log_probs: Optional[np.ndarray] = None
    prompt_embedding_table: Optional[np.ndarray] = None
    prompt_vocab_size: Optional[np.ndarray] = None
    embedding_bias_words: Optional[np.ndarray] = None
    embedding_bias_weights: Optional[np.ndarray] = None
    num_draft_tokens: Optional[np.ndarray] = None
    use_draft_logits: Optional[np.ndarray] = None
    stream: Optional[np.ndarray] = None
    beam_width: Optional[np.ndarray] = None
    return_context_logits: Optional[np.ndarray] = None
    return_generation_logits: Optional[np.ndarray] = None
    random_seed: Optional[np.ndarray] = None
    presence_penalty: Optional[np.ndarray] = None
    frequency_penalty: Optional[np.ndarray] = None

    def validate(self):
        _validate_non_empty(self.text_input, "text_input is required")
        _validate_single_gt_0(self.max_tokens,
                              "max_tokens must be a single value > 0")

        num_draft_tokens = _single_value(self.num_draft_tokens)
        stream = _single_value(self.stream)
        context_logits = _single_value(self.return_context_logits)

        if num_draft_tokens:
            _validate_that(
                not stream,
                "streaming is not supported with speculative decoding")
            _validate_that(
                not context_logits,
                "context logits are not supported with speculative decoding")
|
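# A minimal sketch of building and validating a request; the [[...]] shapes
# are an assumption matching the [batch, value] indexing used below:
#
#     req = Request(text_input=np.array([["hello"]], dtype=object),
#                   max_tokens=np.array([[32]], dtype=np.int32))
#     req.validate()  # raises RequestValidationError on invalid input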
|
|
@dataclass
class DraftRequest:
    draft_input_ids: Optional[np.ndarray] = None
    draft_logits: Optional[np.ndarray] = None
|
|
@dataclass
class PreprocResponse:
    input_ids: np.ndarray = field(default_factory=lambda: np.array([]))
    decoder_input_ids: Optional[np.ndarray] = None
    input_lengths: np.ndarray = field(default_factory=lambda: np.array([]))
    decoder_input_lengths: Optional[np.ndarray] = None
    bad_words_list: Optional[np.ndarray] = None
    stop_words_list: Optional[np.ndarray] = None
    embedding_bias: Optional[np.ndarray] = None
    end_id: Optional[np.ndarray] = None
    pad_id: Optional[np.ndarray] = None

    @classmethod
    def with_new_inputs(cls,
                        other,
                        input_ids: Optional[np.ndarray] = None,
                        input_lengths: Optional[np.ndarray] = None):
        """Copy ``other``, overriding input_ids / input_lengths when given."""
        return cls(
            input_ids=(input_ids
                       if input_ids is not None else other.input_ids),
            input_lengths=(input_lengths if input_lengths is not None else
                           other.input_lengths),
            decoder_input_ids=other.decoder_input_ids,
            decoder_input_lengths=other.decoder_input_lengths,
            bad_words_list=other.bad_words_list,
            stop_words_list=other.stop_words_list,
            end_id=other.end_id,
            pad_id=other.pad_id,
        )
|
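# with_new_inputs is how _spec_generate threads freshly extended token ids
# into the next verification round, e.g.:
#     cur = PreprocResponse.with_new_inputs(cur, new_ids, new_lengths)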
|
|
@dataclass
class GenerationResponse:
    output_ids: np.ndarray = field(default_factory=lambda: np.array([]))
    sequence_length: np.ndarray = field(default_factory=lambda: np.array([]))
    cum_log_probs: Optional[np.ndarray] = None
    output_log_probs: Optional[np.ndarray] = None
    context_logits: Optional[np.ndarray] = None
    generation_logits: Optional[np.ndarray] = None
|
|
@dataclass
class Response:
    text_output: np.ndarray = field(default_factory=lambda: np.array([]))
    cum_log_probs: Optional[np.ndarray] = None
    output_log_probs: Optional[np.ndarray] = None
    context_logits: Optional[np.ndarray] = None
    generation_logits: Optional[np.ndarray] = None

    def __eq__(self, o) -> bool:
        """Just for testing."""
        if not isinstance(o, Response):
            return False
        return (np.array_equal(self.text_output, o.text_output)
                and np.array_equal(self.cum_log_probs, o.cum_log_probs)
                and np.array_equal(self.output_log_probs, o.output_log_probs)
                and np.array_equal(self.context_logits, o.context_logits)
                and np.array_equal(self.generation_logits,
                                   o.generation_logits))
|
|
class Decoder:

    def __init__(self, streaming=False, accumulate=False):
        self._streaming = streaming
        self._accumulate = accumulate

        # Tokens gathered across streamed chunks when accumulation is enabled.
        self._accumulated_tokens = None
|
    def decode(self,
               request: Request,
               speculative_decoding=False) -> Generator[Response, None, None]:
        preproc_response = self.preprocess(request)

        if speculative_decoding:
            for gen_response in self._spec_generate(preproc_response, request):
                yield self.postprocess(gen_response)
        elif self._streaming:
            for gen_response in self._generate(preproc_response, request):
                yield self.postprocess(gen_response)
        else:
            gen_response = self._generate_non_streaming(
                preproc_response, request)
            yield self.postprocess(gen_response)
|
    def encountered_stop_words(self, input_ids, stop_words_ids):
        for stop_word_ids in stop_words_ids:
            if np.array_equal(input_ids[-len(stop_word_ids):], stop_word_ids):
                return True
        return False
|
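    # Example: input_ids ending in [..., 13, 42] with stop_words_ids
    # containing np.array([13, 42]) makes encountered_stop_words return True.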
|
    def _spec_generate(
            self, preproc: PreprocResponse,
            request: Request) -> Generator[GenerationResponse, None, None]:
        """Draft-then-verify loop: the draft model proposes tokens, the
        target model accepts or rewrites them, until a stop condition hits."""

        prompt_input_ids: np.ndarray = preproc.input_ids[0]
        input_ids: np.ndarray = prompt_input_ids
        output_len: int = request.max_tokens[0][0]
        last_input_ids: Optional[np.ndarray] = None
        draft_output_ids: Optional[np.ndarray] = None
        draft_logits: Optional[np.ndarray] = None

        target_response: Optional[GenerationResponse] = None

        cur_preproc = preproc
        use_draft_logits = (request.use_draft_logits is not None
                            and request.use_draft_logits[0])

        while True:
            # Never draft past the remaining output budget.
            num_draft_tokens = min(
                request.num_draft_tokens[0][0],
                len(prompt_input_ids) + output_len - len(input_ids) - 1)

            if num_draft_tokens > 0:
                draft_response: GenerationResponse = self._draft_generate_non_streaming(
                    cur_preproc, request, num_draft_tokens)
                seq_len: int = draft_response.sequence_length[0][0]

                draft_output_ids = draft_response.output_ids[0][0]

                if use_draft_logits and draft_response.generation_logits is not None:
                    draft_logits = draft_response.generation_logits[0][0]

                # Hand the target model only the newly drafted tokens.
                input_draft_tokens = draft_output_ids[len(input_ids):seq_len]
                draft_request = DraftRequest(
                    draft_input_ids=np.expand_dims(input_draft_tokens, 0))
                if use_draft_logits:
                    draft_request.draft_logits = np.expand_dims(
                        draft_logits[-len(input_draft_tokens):], 0)
            else:
                draft_request = DraftRequest()

            target_response = self._generate_non_streaming(
                cur_preproc, request, draft_request)
            last_input_ids = input_ids
            input_ids = target_response.output_ids[0][0]
            cur_preproc = PreprocResponse.with_new_inputs(
                cur_preproc, np.expand_dims(input_ids, 0),
                np.array([[len(input_ids)]], dtype=np.int32))

            # Stop once the requested number of new tokens has been produced.
            length_stop = (len(input_ids) >=
                           len(prompt_input_ids) + output_len)
            if length_stop:
                break

            # Stop if the target accepted the draft output verbatim.
            target_draft_equal = draft_output_ids is not None and np.array_equal(
                draft_output_ids, input_ids)
            if target_draft_equal:
                break

            # Stop if the target produced no new tokens this round.
            last_current_equal = np.array_equal(last_input_ids, input_ids)
            if last_current_equal:
                break

            # Stop on any configured stop-word sequence (skip when none given).
            hit_stop_words = (preproc.stop_words_list is not None
                              and self.encountered_stop_words(
                                  input_ids, preproc.stop_words_list[0]))
            if hit_stop_words:
                break

        yield target_response
|
    def _draft_generate_non_streaming(
            self, preproc: PreprocResponse, request: Request,
            num_draft_tokens: int) -> GenerationResponse:
        raise NotImplementedError()

    def _generate(
            self,
            preproc: PreprocResponse,
            request: Request,
            draft_request: Optional[DraftRequest] = None
    ) -> Generator[GenerationResponse, None, None]:
        raise NotImplementedError()

    def _generate_non_streaming(
            self,
            preproc: PreprocResponse,
            request: Request,
            draft_request: Optional[DraftRequest] = None
    ) -> GenerationResponse:
        raise NotImplementedError()
|
    def postprocess(self, gen_response: GenerationResponse) -> Response:
        if self._accumulate and self._streaming:
            new_tokens: np.ndarray = gen_response.output_ids
            if new_tokens.ndim != 3:
                raise Exception("Expected output_ids tensor to have 3 dims.")
            if new_tokens.shape[0] != 1:
                raise Exception("Expected batch size of 1")
            if new_tokens.shape[1] != 1:
                raise Exception(
                    "Accumulation of tokens is only implemented for beam width = 1"
                )

            # Append the new chunk along the sequence axis and report the
            # accumulated length.
            self._accumulated_tokens = new_tokens if (
                self._accumulated_tokens is None) else np.concatenate(
                    (self._accumulated_tokens, new_tokens), axis=2)
            sequence_lengths = np.array([[self._accumulated_tokens.shape[2]]],
                                        dtype=np.int32)
            return self._postprocess(self._accumulated_tokens,
                                     sequence_lengths, gen_response)
        else:
            return self._postprocess(gen_response.output_ids, None,
                                     gen_response)
|
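    # Example: with streaming and accumulation enabled, three streamed
    # chunks of shape (1, 1, 1) accumulate into tokens of shape (1, 1, 3)
    # and a reported sequence length of [[3]].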
|
    def _postprocess(self, tokens: np.ndarray,
                     sequence_lengths: Optional[np.ndarray],
                     gen_response: GenerationResponse) -> Response:
        raise NotImplementedError()

    def preprocess(self, request: Request) -> PreprocResponse:
        raise NotImplementedError()

    def reset_decoder(self):
        self._accumulated_tokens = None
|
|
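# ---------------------------------------------------------------------------
# Minimal smoke test. ``EchoDecoder`` is a hypothetical toy subclass, not part
# of the real API: it fakes preprocess/generate/postprocess just to show how
# the Decoder hooks fit together (non-streaming path, no draft model).
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class EchoDecoder(Decoder):

        def preprocess(self, request: Request) -> PreprocResponse:
            # Pretend each character of the input is one token id.
            ids = np.array([[ord(c) for c in str(request.text_input[0][0])]],
                           dtype=np.int32)
            return PreprocResponse(
                input_ids=ids,
                input_lengths=np.array([[ids.shape[1]]], dtype=np.int32))

        def _generate_non_streaming(self,
                                    preproc: PreprocResponse,
                                    request: Request,
                                    draft_request=None) -> GenerationResponse:
            # "Generate" by echoing the prompt ids back, shaped
            # [batch, beam, seq] as postprocess expects.
            return GenerationResponse(
                output_ids=np.expand_dims(preproc.input_ids, 0),
                sequence_length=np.array([[preproc.input_ids.shape[1]]],
                                         dtype=np.int32))

        def _postprocess(self, tokens: np.ndarray, sequence_lengths,
                         gen_response: GenerationResponse) -> Response:
            text = "".join(chr(t) for t in tokens[0][0])
            return Response(text_output=np.array([[text]], dtype=object))

    decoder = EchoDecoder()
    req = Request(text_input=np.array([["hi"]], dtype=object),
                  max_tokens=np.array([[8]], dtype=np.int32))
    req.validate()
    for response in decoder.decode(req):
        print(response.text_output)  # -> [['hi']]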