# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
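"""Request/response datatypes and a ``Decoder`` base class for a
token-generation pipeline with optional speculative decoding.

``Decoder.decode`` runs preprocess -> generate -> postprocess and yields one
``Response`` per generation step (or a single response when not streaming).
Concrete subclasses supply the preprocess, generate, and postprocess steps.
"""
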
from collections.abc import Generator
from dataclasses import dataclass, field
from typing import Optional

import numpy as np


class RequestValidationError(Exception):
    pass


def _validate_that(condition: bool, msg: str):
    if not condition:
        raise RequestValidationError(msg)


def _validate_non_empty(data, msg: str):
    _validate_that(data is not None and data.size > 0, msg)


def _validate_single_gt_0(data, msg: str):
    _validate_non_empty(data, msg)
    _validate_that(data.flatten()[0] > 0, msg)


def _single_value(data: Optional[np.ndarray]):
    if data is None:
        return None
    return data.flatten()[0]
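
# Illustrative note (not from the original file): every request field is a
# numpy array, so a scalar parameter arrives wrapped in a nested array and
# _single_value unwraps it, e.g. _single_value(np.array([[4]])) -> 4.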


@dataclass
class Request:
    text_input: np.ndarray = field(default_factory=lambda: np.array([]))
    decoder_text_input: Optional[np.ndarray] = None
    max_tokens: np.ndarray = field(default_factory=lambda: np.array([]))
    bad_words: Optional[np.ndarray] = None
    stop_words: Optional[np.ndarray] = None
    end_id: Optional[np.ndarray] = None
    pad_id: Optional[np.ndarray] = None
    top_k: Optional[np.ndarray] = None
    top_p: Optional[np.ndarray] = None
    temperature: Optional[np.ndarray] = None
    length_penalty: Optional[np.ndarray] = None
    repetition_penalty: Optional[np.ndarray] = None
    min_length: Optional[np.ndarray] = None
    return_log_probs: Optional[np.ndarray] = None
    prompt_embedding_table: Optional[np.ndarray] = None
    prompt_vocab_size: Optional[np.ndarray] = None
    embedding_bias_words: Optional[np.ndarray] = None
    embedding_bias_weights: Optional[np.ndarray] = None
    num_draft_tokens: Optional[np.ndarray] = None
    use_draft_logits: Optional[np.ndarray] = None
    stream: Optional[np.ndarray] = None
    beam_width: Optional[np.ndarray] = None
    return_context_logits: Optional[np.ndarray] = None
    return_generation_logits: Optional[np.ndarray] = None
    random_seed: Optional[np.ndarray] = None
    presence_penalty: Optional[np.ndarray] = None
    frequency_penalty: Optional[np.ndarray] = None

    def validate(self):
        _validate_non_empty(self.text_input, "text_input is required")
        _validate_single_gt_0(self.max_tokens,
                              "max_tokens must be a single value > 0")

        num_draft_tokens = _single_value(self.num_draft_tokens)
        stream = _single_value(self.stream)
        _single_value(self.return_generation_logits)
        context_logits = _single_value(self.return_context_logits)

        if num_draft_tokens:
            _validate_that(
                not stream,
                "streaming is not supported with speculative decoding")
            _validate_that(
                not context_logits,
                "context logits are not supported with speculative decoding")


@dataclass
class DraftRequest:
    draft_input_ids: Optional[np.ndarray] = None
    draft_logits: Optional[np.ndarray] = None


@dataclass
class PreprocResponse:
    input_ids: np.ndarray = field(default_factory=lambda: np.array([]))
    decoder_input_ids: Optional[np.ndarray] = None
    input_lengths: np.ndarray = field(default_factory=lambda: np.array([]))
    decoder_input_lengths: Optional[np.ndarray] = None
    bad_words_list: Optional[np.ndarray] = None
    stop_words_list: Optional[np.ndarray] = None
    embedding_bias: Optional[np.ndarray] = None
    end_id: Optional[np.ndarray] = None
    pad_id: Optional[np.ndarray] = None

    @classmethod
    def with_new_inputs(cls,
                        other,
                        input_ids: Optional[np.ndarray] = None,
                        input_lengths: Optional[np.ndarray] = None):
        return cls(
            input_ids=(input_ids
                       if input_ids is not None else other.input_ids),
            input_lengths=(input_lengths if input_lengths is not None else
                           other.input_lengths),
            decoder_input_ids=other.decoder_input_ids,
            decoder_input_lengths=other.decoder_input_lengths,
            bad_words_list=other.bad_words_list,
            stop_words_list=other.stop_words_list,
            end_id=other.end_id,
            pad_id=other.pad_id,
        )
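
    # Note (not in the original file): the speculative-decoding loop below
    # uses with_new_inputs to restart generation from the accepted tokens.
    # Observe that embedding_bias is not carried over to the new instance.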


@dataclass
class GenerationResponse:
    output_ids: np.ndarray = field(default_factory=lambda: np.array([]))
    sequence_length: np.ndarray = field(default_factory=lambda: np.array([]))
    cum_log_probs: Optional[np.ndarray] = None
    output_log_probs: Optional[np.ndarray] = None
    context_logits: Optional[np.ndarray] = None
    generation_logits: Optional[np.ndarray] = None


@dataclass
class Response:
    text_output: np.ndarray = field(default_factory=lambda: np.array([]))
    cum_log_probs: Optional[np.ndarray] = None
    output_log_probs: Optional[np.ndarray] = None
    context_logits: Optional[np.ndarray] = None
    generation_logits: Optional[np.ndarray] = None

    def __eq__(self, o) -> bool:
        """Just for testing"""
        if not isinstance(o, Response):
            return False
        return (np.array_equal(self.text_output, o.text_output)
                and np.array_equal(self.cum_log_probs, o.cum_log_probs)
                and np.array_equal(self.output_log_probs, o.output_log_probs)
                and np.array_equal(self.context_logits, o.context_logits) and
                np.array_equal(self.generation_logits, o.generation_logits))


class Decoder:

    def __init__(self, streaming=False, accumulate=False):
        self._streaming = streaming
        self._accumulate = accumulate
        self._accumulated_tokens = None

    def decode(self,
               request: Request,
               speculative_decoding=False) -> Generator[Response, None, None]:
        preproc_response = self.preprocess(request)

        if speculative_decoding:
            for gen_response in self._spec_generate(preproc_response,
                                                    request):
                yield self.postprocess(gen_response)
        else:
            if not self._streaming:
                gen_response = self._generate_non_streaming(
                    preproc_response, request)
                yield self.postprocess(gen_response)
            else:
                for gen_response in self._generate(preproc_response, request):
                    yield self.postprocess(gen_response)
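
    # Illustrative usage sketch (ConcreteDecoder is a hypothetical subclass,
    # not part of the original file):
    #
    #   decoder = ConcreteDecoder(streaming=True, accumulate=True)
    #   for response in decoder.decode(request):
    #       ...  # stream each partial Response to the client
    #   decoder.reset_decoder()  # clear accumulated tokens between requests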

    def encountered_stop_words(self, input_ids, stop_words_ids):
        for stop_word_ids in stop_words_ids:
            if np.array_equal(input_ids[-len(stop_word_ids):], stop_word_ids):
                return True
        return False
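
    # Illustrative example (not from the original file): with
    # input_ids = np.array([5, 6, 7]) and stop_words_ids = [np.array([6, 7])],
    # the trailing [6, 7] matches, so the method returns True.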

    def _spec_generate(
            self, preproc: PreprocResponse,
            request: Request) -> Generator[GenerationResponse, None, None]:
        prompt_input_ids: np.ndarray = preproc.input_ids[0]
        input_ids: np.ndarray = prompt_input_ids
        output_len: int = request.max_tokens[0][0]
        last_input_ids: Optional[np.ndarray] = None
        draft_output_ids: Optional[np.ndarray] = None
        draft_logits: Optional[np.ndarray] = None

        target_response: Optional[GenerationResponse] = None

        cur_preproc = preproc

        counter = 0
        while True:
            counter += 1
            num_draft_tokens = min(
                request.num_draft_tokens[0][0],
                len(prompt_input_ids) + output_len - len(input_ids) - 1)

            draft_request = None
            if num_draft_tokens > 0:
                draft_response: GenerationResponse = self._draft_generate_non_streaming(
                    cur_preproc, request, num_draft_tokens)
                seq_len: int = draft_response.sequence_length[0][0]
                # [1, beamWidth, outputLength] -> [outputLength]
                draft_output_ids = draft_response.output_ids[0][0]
                # [1, beamWidth, outputLength, vocabSizePadded] ->
                # [outputLength, vocabSizePadded]
                if (request.use_draft_logits is not None
                        and request.use_draft_logits[0]):
                    if draft_response.generation_logits is not None:
                        draft_logits = draft_response.generation_logits[0][0]

                input_draft_tokens = draft_output_ids[len(input_ids):seq_len]
                draft_request = DraftRequest(
                    draft_input_ids=np.expand_dims(input_draft_tokens, 0))
                if (request.use_draft_logits is not None
                        and request.use_draft_logits[0]):
                    draft_request.draft_logits = np.expand_dims(
                        draft_logits[-len(input_draft_tokens):], 0)
            else:
                draft_request = DraftRequest()

            target_response = self._generate_non_streaming(
                cur_preproc, request, draft_request)
            last_input_ids = input_ids
            input_ids = target_response.output_ids[0][0]
            cur_preproc = PreprocResponse.with_new_inputs(
                cur_preproc, np.expand_dims(input_ids, 0),
                np.array([[len(input_ids)]], dtype=np.int32))

            # Evaluate the criteria to stop the generation loop.
            # If we've hit or exceeded the max output length, stop.
            length_stop = (len(input_ids) >=
                           len(prompt_input_ids) + output_len)
            if length_stop:
                break

            # If draft and target produced identical outputs, stop. Normally
            # the target returns one more token; if the outputs are the same
            # length, they should differ at the last token.
            target_draft_equal = draft_output_ids is not None and np.array_equal(
                draft_output_ids, input_ids)
            if target_draft_equal:
                break

            # If the tokens no longer change, early stopping was hit; stop.
            last_current_equal = np.array_equal(last_input_ids, input_ids)
            if last_current_equal:
                break

            # Check whether any stop words were encountered. Guard against
            # stop_words_list being None, since it is optional.
            if preproc.stop_words_list is not None:
                hit_stop_words = self.encountered_stop_words(
                    input_ids, preproc.stop_words_list[0])
                if hit_stop_words:
                    break

        yield target_response

    def _draft_generate_non_streaming(
            self, preproc: PreprocResponse, request: Request,
            num_draft_tokens: int) -> GenerationResponse:
        raise NotImplementedError()

    def _generate(
        self,
        preproc: PreprocResponse,
        request: Request,
        draft_request: Optional[DraftRequest] = None
    ) -> Generator[GenerationResponse, None, None]:
        raise NotImplementedError()

    def _generate_non_streaming(
        self,
        preproc: PreprocResponse,
        request: Request,
        draft_request: Optional[DraftRequest] = None
    ) -> GenerationResponse:
        raise NotImplementedError()
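
    # Note for implementers (shapes inferred from the indexing in
    # _spec_generate above, not stated explicitly in the original file):
    # _draft_generate_non_streaming is expected to return output_ids shaped
    # [1, beam_width, output_length] and, when draft logits are requested,
    # generation_logits shaped [1, beam_width, output_length, vocab_size_padded].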

    def postprocess(self, gen_response: GenerationResponse) -> Response:
        if self._accumulate and self._streaming:
            new_tokens: np.ndarray = gen_response.output_ids
            if new_tokens.ndim != 3:
                raise Exception("Expected output_ids tensor to have 3 dims.")
            if new_tokens.shape[0] != 1:
                raise Exception("Expected batch size of 1")
            if new_tokens.shape[1] != 1:
                raise Exception(
                    "Accumulation of tokens is only implemented for beam width = 1"
                )

            self._accumulated_tokens = new_tokens if (
                self._accumulated_tokens is None) else np.concatenate(
                    (self._accumulated_tokens, new_tokens), axis=2)
            sequence_lengths = np.array([[self._accumulated_tokens.shape[2]]],
                                        dtype=np.int32)
            return self._postprocess(self._accumulated_tokens,
                                     sequence_lengths, gen_response)
        else:
            return self._postprocess(gen_response.output_ids, None,
                                     gen_response)

    def _postprocess(self, tokens: np.ndarray,
                     sequence_lengths: Optional[np.ndarray],
                     gen_response: GenerationResponse) -> Response:
        raise NotImplementedError()

    def preprocess(self, request: Request) -> PreprocResponse:
        raise NotImplementedError()

    def reset_decoder(self):
        self._accumulated_tokens = None
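

# Minimal subclass sketch (illustrative only, all names below are
# hypothetical): a concrete decoder implements preprocess, one or both of the
# _generate* methods, and _postprocess.
#
#   class EchoDecoder(Decoder):
#
#       def preprocess(self, request: Request) -> PreprocResponse:
#           ids = np.array([[1, 2, 3]], dtype=np.int32)  # stand-in tokenizer
#           return PreprocResponse(
#               input_ids=ids, input_lengths=np.array([[3]], dtype=np.int32))
#
#       def _generate_non_streaming(self, preproc, request, draft_request=None):
#           # Echo the input back as [batch=1, beam=1, output_length].
#           return GenerationResponse(
#               output_ids=preproc.input_ids[None, ...],
#               sequence_length=np.array([[3]], dtype=np.int32))
#
#       def _postprocess(self, tokens, sequence_lengths, gen_response):
#           return Response(text_output=tokens.astype(np.str_))  # stand-in detok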