# Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 """Implements a OpenAI chat and causal LM inference API wrappers.""" import logging import os from time import sleep from typing import Any, Dict, List, Optional, Union import torch from composer.core.types import Batch from composer.utils.import_helpers import MissingConditionalImportError from transformers import AutoTokenizer log = logging.getLogger(__name__) from llmfoundry.models.inference_api_wrapper.interface import \ InferenceAPIEvalWrapper __all__ = [ 'OpenAICausalLMEvalWrapper', 'OpenAIChatAPIEvalWrapper', ] MAX_RETRIES = 10 class OpenAIEvalInterface(InferenceAPIEvalWrapper): def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None: super().__init__(model_cfg, tokenizer) try: import openai except ImportError as e: raise MissingConditionalImportError( extra_deps_group='openai', conda_package='openai', conda_channel='conda-forge') from e openai.api_key = os.getenv('OPENAI_API_KEY') self.model_name = model_cfg['version'] def generate_completion(self, prompt: str, num_tokens: int): raise NotImplementedError() def process_result(self, completion: Optional[dict]): raise NotImplementedError() def get_next_token_logit_tensor(self, prompt: str, num_tokens: int = 1): completion = self.try_generate_completion(prompt, num_tokens) return self.process_result(completion) def try_generate_completion(self, prompt: str, num_tokens: int): try: from openai.error import RateLimitError except ImportError as e: raise MissingConditionalImportError( extra_deps_group='openai', conda_package='openai', conda_channel='conda-forge') from e tries = 0 completion = None while tries < MAX_RETRIES: tries += 1 try: completion = self.generate_completion(prompt, num_tokens) break except RateLimitError as e: if 'You exceeded your current quota' in str(e._message): raise e sleep(60) continue except Exception: continue return completion class OpenAIChatAPIEvalWrapper(OpenAIEvalInterface): def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None: super().__init__(model_cfg, tokenizer) try: import openai except ImportError as e: raise MissingConditionalImportError( extra_deps_group='openai', conda_package='openai', conda_channel='conda-forge') from e self.generate_completion = lambda prompt, num_tokens: openai.ChatCompletion.create( self.model_name, messages=[{ 'role': 'user', 'content': prompt }], max_tokens=num_tokens, temperature=0.0) def retokenize(self, tokens: List[int], cont_idxs: List[int]): """Chat API will never respond with a word-initial space. If the continuation tokens begin with a word initial space, we need to re-tokenize with the space removed. """ original_len = len(tokens) retokenized_continuation = self.tokenizer( self.tokenizer.decode(tokens[cont_idxs[0]:cont_idxs[-1] + 1]).strip())['input_ids'] # replace the original continuation with the retokenized continuation + padding padding = [tokens[-1]] * ( len(tokens) - len(tokens[:cont_idxs[0]] + retokenized_continuation)) tokens = tokens[:cont_idxs[0]] + retokenized_continuation + padding if len(tokens) > original_len: # this only happens if we were already at max seq len and the continuation got LARGER tokens = tokens[-original_len:] cont_idxs = list( range(original_len - len(retokenized_continuation), original_len)) else: cont_idxs = list( range(cont_idxs[0], cont_idxs[0] + len(retokenized_continuation))) return torch.tensor(tokens), torch.tensor(cont_idxs) def rebatch(self, batch: Batch): """Chat API tokenization has different behavior than GPT3. Model responses will never begin with spaces even if the continuation is expected to, so we need to retokenize the input to account for that. """ new_batch: Dict[str, Union[List[torch.Tensor], torch.Tensor]] = { 'input_ids': [], 'continuation_indices': [], 'labels': [] } for tokens, cont_idxs in zip(batch['input_ids'], batch['continuation_indices']): tokens, cont_idxs = self.retokenize(tokens.tolist(), cont_idxs.tolist()) assert isinstance(new_batch['input_ids'], list) new_batch['input_ids'].append(tokens) assert isinstance(new_batch['labels'], list) new_batch['labels'].append(tokens) assert isinstance(new_batch['continuation_indices'], list) new_batch['continuation_indices'].append(cont_idxs) new_batch.update({ k: torch.stack(new_batch[k]) # pyright: ignore for k in ['input_ids', 'labels'] }) new_batch.update({k: v for k, v in batch.items() if k not in new_batch}) return new_batch def eval_forward(self, batch: Batch, outputs: Optional[Any] = None): # Override the base class because Chat's API always strips spacing from model outputs resulting in different tokens # than what the continuation would expect. # Get around this issue by retokenizing the batch to remove spacing from the continuation as well as # decoding the whole continuation at once. output_logits_batch = [] batch = self.rebatch(batch) for tokens, cont_idxs in zip(batch['input_ids'], batch['continuation_indices']): seqlen = tokens.shape[0] tokens = tokens.tolist() cont_idxs = cont_idxs.tolist() expected_cont_tokens = tokens[cont_idxs[0]:cont_idxs[-1] + 1] output_logits = torch.nn.functional.one_hot( torch.tensor(tokens[1:cont_idxs[0]]), num_classes=self.tokenizer.vocab_size) prompt = self.tokenizer.decode(tokens[:cont_idxs[0]]) next_logit_tensor = self.get_next_token_logit_tensor( prompt, num_tokens=len(expected_cont_tokens)) if next_logit_tensor is not None: output_logits = torch.cat([output_logits, next_logit_tensor]) padding = torch.nn.functional.one_hot( torch.full((seqlen - output_logits.shape[0],), self.tokenizer.pad_token_id), num_classes=self.tokenizer.vocab_size) output_logits = torch.cat([output_logits, padding]) output_logits_batch.append(output_logits) return torch.stack(output_logits_batch).to(batch['input_ids'].device) def process_result(self, completion: Optional[dict]): assert isinstance(completion, dict) if len(completion['choices']) > 0: tensors = [] for t in self.tokenizer(completion['choices'][0]['message'] ['content'])['input_ids']: tensors.append( self.tokenizer.construct_logit_tensor( {self.tokenizer.decode([t]): 0.0})) if len(tensors) == 0: return None return torch.stack(tensors) else: # the model sometimes stops early even though we are still requesting tokens! # not sure if there's a fix return None class OpenAICausalLMEvalWrapper(OpenAIEvalInterface): def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None: super().__init__(model_cfg, tokenizer) try: import openai except ImportError as e: raise MissingConditionalImportError( extra_deps_group='openai', conda_package='openai', conda_channel='conda-forge') from e self.generate_completion = lambda prompt, num_tokens: openai.Completion.create( engine=self.model_name, prompt=prompt, max_tokens=1, logprobs=5, temperature=0.0) def process_result(self, completion: Optional[dict]): if completion is None: raise ValueError("Couldn't generate model output") assert isinstance(completion, dict) if len(completion['choices'][0]['logprobs']['top_logprobs']) > 0: tensor = self.tokenizer.construct_logit_tensor( dict(completion['choices'][0]['logprobs']['top_logprobs'][0])) return tensor else: # the model sometimes stops early even though we are still requesting tokens! # not sure if there's a fix return None