# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
"""Implements a OpenAI chat and causal LM inference API wrappers."""
import logging
import os
from time import sleep
from typing import Any, Dict, List, Optional, Union

import torch
from composer.core.types import Batch
from composer.utils.import_helpers import MissingConditionalImportError
from transformers import AutoTokenizer

from llmfoundry.models.inference_api_wrapper.interface import \
    InferenceAPIEvalWrapper

log = logging.getLogger(__name__)

__all__ = [
'OpenAICausalLMEvalWrapper',
'OpenAIChatAPIEvalWrapper',
]
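
# Maximum number of attempts per API request in
# ``OpenAIEvalInterface.try_generate_completion``; rate-limit errors back off
# for 60 seconds between attempts, while other errors retry immediately.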
MAX_RETRIES = 10


class OpenAIEvalInterface(InferenceAPIEvalWrapper):
    """Base class for OpenAI API evaluation wrappers."""

def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None:
super().__init__(model_cfg, tokenizer)
try:
import openai
except ImportError as e:
raise MissingConditionalImportError(
extra_deps_group='openai',
conda_package='openai',
conda_channel='conda-forge') from e
openai.api_key = os.getenv('OPENAI_API_KEY')
self.model_name = model_cfg['version']

    def generate_completion(self, prompt: str, num_tokens: int):
        raise NotImplementedError()

    def process_result(self, completion: Optional[dict]):
        raise NotImplementedError()

    def get_next_token_logit_tensor(self, prompt: str, num_tokens: int = 1):
        completion = self.try_generate_completion(prompt, num_tokens)
        return self.process_result(completion)

    def try_generate_completion(self, prompt: str, num_tokens: int):
try:
from openai.error import RateLimitError
except ImportError as e:
raise MissingConditionalImportError(
extra_deps_group='openai',
conda_package='openai',
conda_channel='conda-forge') from e
tries = 0
completion = None
while tries < MAX_RETRIES:
tries += 1
try:
completion = self.generate_completion(prompt, num_tokens)
break
            except RateLimitError as e:
                if 'You exceeded your current quota' in str(e._message):
                    # A hard quota error will not resolve on retry; re-raise it.
                    raise e
                # Otherwise back off before retrying.
                sleep(60)
                continue
            except Exception:
                # Retry immediately on any other transient API error.
                continue
return completion


class OpenAIChatAPIEvalWrapper(OpenAIEvalInterface):
    """Evaluation wrapper for the OpenAI chat completions endpoint."""

def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None:
super().__init__(model_cfg, tokenizer)
try:
import openai
except ImportError as e:
raise MissingConditionalImportError(
extra_deps_group='openai',
conda_package='openai',
conda_channel='conda-forge') from e
        self.generate_completion = lambda prompt, num_tokens: openai.ChatCompletion.create(
            model=self.model_name,
            messages=[{
                'role': 'user',
                'content': prompt
            }],
            max_tokens=num_tokens,
            temperature=0.0)

    def retokenize(self, tokens: List[int], cont_idxs: List[int]):
        """Chat API will never respond with a word-initial space.

        If the continuation tokens begin with a word-initial space, we need to
        re-tokenize with the space removed.
        """
original_len = len(tokens)
retokenized_continuation = self.tokenizer(
self.tokenizer.decode(tokens[cont_idxs[0]:cont_idxs[-1] +
1]).strip())['input_ids']
# replace the original continuation with the retokenized continuation + padding
padding = [tokens[-1]] * (
len(tokens) - len(tokens[:cont_idxs[0]] + retokenized_continuation))
tokens = tokens[:cont_idxs[0]] + retokenized_continuation + padding
if len(tokens) > original_len:
# this only happens if we were already at max seq len and the continuation got LARGER
tokens = tokens[-original_len:]
cont_idxs = list(
range(original_len - len(retokenized_continuation),
original_len))
else:
cont_idxs = list(
range(cont_idxs[0],
cont_idxs[0] + len(retokenized_continuation)))
return torch.tensor(tokens), torch.tensor(cont_idxs)

    def rebatch(self, batch: Batch):
        """Chat API tokenization has different behavior than GPT-3's.

        Model responses will never begin with spaces even if the continuation
        is expected to, so we need to re-tokenize the input to account for
        that.
        """
new_batch: Dict[str, Union[List[torch.Tensor], torch.Tensor]] = {
'input_ids': [],
'continuation_indices': [],
'labels': []
}
for tokens, cont_idxs in zip(batch['input_ids'],
batch['continuation_indices']):
tokens, cont_idxs = self.retokenize(tokens.tolist(),
cont_idxs.tolist())
assert isinstance(new_batch['input_ids'], list)
new_batch['input_ids'].append(tokens)
assert isinstance(new_batch['labels'], list)
new_batch['labels'].append(tokens)
assert isinstance(new_batch['continuation_indices'], list)
new_batch['continuation_indices'].append(cont_idxs)
new_batch.update({
k: torch.stack(new_batch[k]) # pyright: ignore
for k in ['input_ids', 'labels']
})
new_batch.update({k: v for k, v in batch.items() if k not in new_batch})
return new_batch

    def eval_forward(self, batch: Batch, outputs: Optional[Any] = None):
        # Override the base class because the Chat API always strips spacing
        # from model outputs, resulting in different tokens than the
        # continuation would expect. Work around this by re-tokenizing the
        # batch to remove spacing from the continuation and by decoding the
        # whole continuation at once.
output_logits_batch = []
batch = self.rebatch(batch)
for tokens, cont_idxs in zip(batch['input_ids'],
batch['continuation_indices']):
seqlen = tokens.shape[0]
tokens = tokens.tolist()
cont_idxs = cont_idxs.tolist()
expected_cont_tokens = tokens[cont_idxs[0]:cont_idxs[-1] + 1]
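            # Prompt positions get deterministic one-hot "logits"; only the
            # continuation positions are filled from the API response below.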
output_logits = torch.nn.functional.one_hot(
torch.tensor(tokens[1:cont_idxs[0]]),
num_classes=self.tokenizer.vocab_size)
prompt = self.tokenizer.decode(tokens[:cont_idxs[0]])
next_logit_tensor = self.get_next_token_logit_tensor(
prompt, num_tokens=len(expected_cont_tokens))
if next_logit_tensor is not None:
output_logits = torch.cat([output_logits, next_logit_tensor])
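            # Pad the remaining positions with one-hot logits for the pad token
            # so every example stacks to the same sequence length.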
padding = torch.nn.functional.one_hot(
torch.full((seqlen - output_logits.shape[0],),
self.tokenizer.pad_token_id),
num_classes=self.tokenizer.vocab_size)
output_logits = torch.cat([output_logits, padding])
output_logits_batch.append(output_logits)
return torch.stack(output_logits_batch).to(batch['input_ids'].device)

    def process_result(self, completion: Optional[dict]):
assert isinstance(completion, dict)
if len(completion['choices']) > 0:
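            # The Chat API returns no logprobs, so give each returned token
            # probability one (logprob 0.0) at its own index.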
tensors = []
for t in self.tokenizer(completion['choices'][0]['message']
['content'])['input_ids']:
tensors.append(
self.tokenizer.construct_logit_tensor(
{self.tokenizer.decode([t]): 0.0}))
if len(tensors) == 0:
return None
return torch.stack(tensors)
else:
# the model sometimes stops early even though we are still requesting tokens!
# not sure if there's a fix
return None


class OpenAICausalLMEvalWrapper(OpenAIEvalInterface):
    """Evaluation wrapper for the OpenAI completions endpoint."""

def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None:
super().__init__(model_cfg, tokenizer)
try:
import openai
except ImportError as e:
raise MissingConditionalImportError(
extra_deps_group='openai',
conda_package='openai',
conda_channel='conda-forge') from e
self.generate_completion = lambda prompt, num_tokens: openai.Completion.create(
engine=self.model_name,
prompt=prompt,
max_tokens=1,
logprobs=5,
temperature=0.0)

    def process_result(self, completion: Optional[dict]):
if completion is None:
raise ValueError("Couldn't generate model output")
assert isinstance(completion, dict)
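        # ``logprobs=5`` in the request above returns the top logprobs for the
        # next token; build the logit tensor from that mapping.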
if len(completion['choices'][0]['logprobs']['top_logprobs']) > 0:
tensor = self.tokenizer.construct_logit_tensor(
dict(completion['choices'][0]['logprobs']['top_logprobs'][0]))
return tensor
else:
# the model sometimes stops early even though we are still requesting tokens!
# not sure if there's a fix
return None
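

# Example usage (an illustrative sketch, not part of the library; the tokenizer
# and model names below are assumptions, and OPENAI_API_KEY must be set):
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained('gpt2')
#     wrapper = OpenAICausalLMEvalWrapper(model_cfg={'version': 'davinci-002'},
#                                         tokenizer=tokenizer)
#     logits = wrapper.get_next_token_logit_tensor('The capital of France is')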