"""Implements a OpenAI chat and causal LM inference API wrappers.""" |
|
|
|
import logging |
|
import os |
|
from time import sleep |
|
from typing import Any, Dict, List, Optional, Union |
|
|
|
import torch |
|
from composer.core.types import Batch |
|
from composer.utils.import_helpers import MissingConditionalImportError |
|
from transformers import AutoTokenizer |
|
|
|
log = logging.getLogger(__name__) |
|
|
|
from llmfoundry.models.inference_api_wrapper.interface import \ |
|
InferenceAPIEvalWrapper |
|
|
|
__all__ = [ |
|
'OpenAICausalLMEvalWrapper', |
|
'OpenAIChatAPIEvalWrapper', |
|
] |
|
|
|
MAX_RETRIES = 10 |
|
|
|
|
|
class OpenAIEvalInterface(InferenceAPIEvalWrapper): |
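    """Base class for the OpenAI API evaluation wrappers.

    Reads the API key from the ``OPENAI_API_KEY`` environment variable and
    selects the model via ``model_cfg['version']``. Subclasses implement
    ``generate_completion`` and ``process_result``.
    """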
|
|
|
    def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None:
        super().__init__(model_cfg, tokenizer)
        try:
            import openai
        except ImportError as e:
            raise MissingConditionalImportError(
                extra_deps_group='openai',
                conda_package='openai',
                conda_channel='conda-forge') from e
        openai.api_key = os.getenv('OPENAI_API_KEY')
        self.model_name = model_cfg['version']

    def generate_completion(self, prompt: str, num_tokens: int):
        raise NotImplementedError()

    def process_result(self, completion: Optional[dict]):
        raise NotImplementedError()

    def get_next_token_logit_tensor(self, prompt: str, num_tokens: int = 1):
        completion = self.try_generate_completion(prompt, num_tokens)
        return self.process_result(completion)

    def try_generate_completion(self, prompt: str, num_tokens: int):
        try:
            from openai.error import RateLimitError
        except ImportError as e:
            raise MissingConditionalImportError(
                extra_deps_group='openai',
                conda_package='openai',
                conda_channel='conda-forge') from e
        tries = 0
        completion = None
        while tries < MAX_RETRIES:
            tries += 1
            try:
                completion = self.generate_completion(prompt, num_tokens)
                break
            except RateLimitError as e:
                # Quota errors will not resolve by waiting, so re-raise them;
                # otherwise back off for a minute and retry.
                if 'You exceeded your current quota' in str(e._message):
                    raise e
                sleep(60)
                continue
            except Exception:
                continue
        return completion
|
|
class OpenAIChatAPIEvalWrapper(OpenAIEvalInterface): |
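    """Evaluation wrapper for the OpenAI chat completions endpoint.

    The prompt is sent as a single user message, and the response text is
    converted back into per-token logit tensors for the eval metrics.
    """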
|
|
|
    def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None:
        super().__init__(model_cfg, tokenizer)
        try:
            import openai
        except ImportError as e:
            raise MissingConditionalImportError(
                extra_deps_group='openai',
                conda_package='openai',
                conda_channel='conda-forge') from e

        self.generate_completion = lambda prompt, num_tokens: openai.ChatCompletion.create(
            model=self.model_name,
            messages=[{
                'role': 'user',
                'content': prompt
            }],
            max_tokens=num_tokens,
            temperature=0.0)
|
    def retokenize(self, tokens: List[int], cont_idxs: List[int]):
        """Chat API will never respond with a word-initial space.

        If the continuation tokens begin with a word-initial space, we need to
        re-tokenize with the space removed.
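
        For example (hypothetical tokenization; exact ids depend on the
        tokenizer), a continuation originally tokenized from ' Paris' will be
        re-tokenized from 'Paris', and the sequence is padded back out to its
        original length.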
|
""" |
|
original_len = len(tokens) |
|
retokenized_continuation = self.tokenizer( |
|
self.tokenizer.decode(tokens[cont_idxs[0]:cont_idxs[-1] + |
|
1]).strip())['input_ids'] |
|
|
|
|
|
padding = [tokens[-1]] * ( |
|
len(tokens) - len(tokens[:cont_idxs[0]] + retokenized_continuation)) |
|
tokens = tokens[:cont_idxs[0]] + retokenized_continuation + padding |
|
|
|
if len(tokens) > original_len: |
|
|
|
tokens = tokens[-original_len:] |
|
cont_idxs = list( |
|
range(original_len - len(retokenized_continuation), |
|
original_len)) |
|
else: |
|
cont_idxs = list( |
|
range(cont_idxs[0], |
|
cont_idxs[0] + len(retokenized_continuation))) |
|
return torch.tensor(tokens), torch.tensor(cont_idxs) |
|
|
|
    def rebatch(self, batch: Batch):
        """Chat API tokenization has different behavior than GPT-3's.

        Model responses will never begin with spaces even if the continuation
        is expected to, so we need to retokenize the input to account for that.
        """
        new_batch: Dict[str, Union[List[torch.Tensor], torch.Tensor]] = {
            'input_ids': [],
            'continuation_indices': [],
            'labels': []
        }
        for tokens, cont_idxs in zip(batch['input_ids'],
                                     batch['continuation_indices']):
            tokens, cont_idxs = self.retokenize(tokens.tolist(),
                                                cont_idxs.tolist())

            assert isinstance(new_batch['input_ids'], list)
            new_batch['input_ids'].append(tokens)
            assert isinstance(new_batch['labels'], list)
            new_batch['labels'].append(tokens)
            assert isinstance(new_batch['continuation_indices'], list)
            new_batch['continuation_indices'].append(cont_idxs)

        new_batch.update({
            k: torch.stack(new_batch[k])
            for k in ['input_ids', 'labels']
        })

        new_batch.update({k: v for k, v in batch.items() if k not in new_batch})

        return new_batch
|
    def eval_forward(self, batch: Batch, outputs: Optional[Any] = None):
        # The chat API strips leading spaces from model outputs, which produces
        # different tokens than the continuation expects, so retokenize the
        # batch before scoring it.
        output_logits_batch = []
        batch = self.rebatch(batch)
        for tokens, cont_idxs in zip(batch['input_ids'],
                                     batch['continuation_indices']):

            seqlen = tokens.shape[0]
            tokens = tokens.tolist()
            cont_idxs = cont_idxs.tolist()
            expected_cont_tokens = tokens[cont_idxs[0]:cont_idxs[-1] + 1]
            # One-hot encode the prompt tokens as placeholder logits for the
            # positions preceding the continuation.
            output_logits = torch.nn.functional.one_hot(
                torch.tensor(tokens[1:cont_idxs[0]]),
                num_classes=self.tokenizer.vocab_size)

            prompt = self.tokenizer.decode(tokens[:cont_idxs[0]])
            next_logit_tensor = self.get_next_token_logit_tensor(
                prompt, num_tokens=len(expected_cont_tokens))

            if next_logit_tensor is not None:
                output_logits = torch.cat([output_logits, next_logit_tensor])
            # Pad the remaining positions with one-hot pad tokens so every
            # sequence in the batch has length `seqlen`.
            padding = torch.nn.functional.one_hot(
                torch.full((seqlen - output_logits.shape[0],),
                           self.tokenizer.pad_token_id),
                num_classes=self.tokenizer.vocab_size)
            output_logits = torch.cat([output_logits, padding])
            output_logits_batch.append(output_logits)

        return torch.stack(output_logits_batch).to(batch['input_ids'].device)
|
    def process_result(self, completion: Optional[dict]):
        assert isinstance(completion, dict)
        if len(completion['choices']) > 0:
            tensors = []
            for t in self.tokenizer(completion['choices'][0]['message']
                                    ['content'])['input_ids']:
                # Convert each returned token into a logit tensor via the
                # token -> logprob (0.0) mapping.
                tensors.append(
                    self.tokenizer.construct_logit_tensor(
                        {self.tokenizer.decode([t]): 0.0}))

            if len(tensors) == 0:
                return None
            return torch.stack(tensors)
        else:
            # The API returned no choices; treat this as a failed generation.
            return None
|
|
class OpenAICausalLMEvalWrapper(OpenAIEvalInterface): |
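    """Evaluation wrapper for the OpenAI completions endpoint.

    Requests the top-5 logprobs for the next token and converts them into a
    logit tensor. A minimal usage sketch (hypothetical model name; assumes
    ``OPENAI_API_KEY`` is set, the ``openai`` extra is installed, and the
    tokenizer exposes the ``construct_logit_tensor`` helper used by
    ``process_result``)::

        wrapper = OpenAICausalLMEvalWrapper({'version': 'davinci-002'},
                                            tokenizer)
        logits = wrapper.get_next_token_logit_tensor('The capital of France is')
    """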
|
|
|
    def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None:
        super().__init__(model_cfg, tokenizer)
        try:
            import openai
        except ImportError as e:
            raise MissingConditionalImportError(
                extra_deps_group='openai',
                conda_package='openai',
                conda_channel='conda-forge') from e

        self.generate_completion = lambda prompt, num_tokens: openai.Completion.create(
            engine=self.model_name,
            prompt=prompt,
            max_tokens=num_tokens,
            logprobs=5,
            temperature=0.0)
|
    def process_result(self, completion: Optional[dict]):
        if completion is None:
            raise ValueError("Couldn't generate model output")

        assert isinstance(completion, dict)
        if len(completion['choices'][0]['logprobs']['top_logprobs']) > 0:
            tensor = self.tokenizer.construct_logit_tensor(
                dict(completion['choices'][0]['logprobs']['top_logprobs'][0]))
            return tensor
        else:
            # The API returned no logprobs for the completion; treat this as a
            # failed generation.
            return None
|
|