# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

"""Implements a OpenAI chat and causal LM inference API wrappers."""

import logging
import os
from time import sleep
from typing import Any, Dict, List, Optional, Union

import torch
from composer.core.types import Batch
from composer.utils.import_helpers import MissingConditionalImportError
from transformers import AutoTokenizer

from llmfoundry.models.inference_api_wrapper.interface import \
    InferenceAPIEvalWrapper

log = logging.getLogger(__name__)

__all__ = [
    'OpenAICausalLMEvalWrapper',
    'OpenAIChatAPIEvalWrapper',
]

MAX_RETRIES = 10


class OpenAIEvalInterface(InferenceAPIEvalWrapper):

    def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None:
        super().__init__(model_cfg, tokenizer)
        try:
            import openai
        except ImportError as e:
            raise MissingConditionalImportError(
                extra_deps_group='openai',
                conda_package='openai',
                conda_channel='conda-forge') from e
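        # Authenticate with the OPENAI_API_KEY environment variable; the model
        # version to query is taken from the model config.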
        openai.api_key = os.getenv('OPENAI_API_KEY')
        self.model_name = model_cfg['version']

    def generate_completion(self, prompt: str, num_tokens: int):
        raise NotImplementedError()

    def process_result(self, completion: Optional[dict]):
        raise NotImplementedError()

    def get_next_token_logit_tensor(self, prompt: str, num_tokens: int = 1):
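        # Generate a completion for the prompt and convert the response into logit tensors.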
        completion = self.try_generate_completion(prompt, num_tokens)
        return self.process_result(completion)

    def try_generate_completion(self, prompt: str, num_tokens: int):
        try:
            from openai.error import RateLimitError
        except ImportError as e:
            raise MissingConditionalImportError(
                extra_deps_group='openai',
                conda_package='openai',
                conda_channel='conda-forge') from e
        tries = 0
        completion = None
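        # Retry the request up to MAX_RETRIES times; if every attempt fails, return None.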
        while tries < MAX_RETRIES:
            tries += 1
            try:
                completion = self.generate_completion(prompt, num_tokens)
                break
            except RateLimitError as e:
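                # A hard quota error is not retryable; otherwise back off for a minute.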
                if 'You exceeded your current quota' in str(e._message):
                    raise e
                sleep(60)
                continue
            except Exception:
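                # Treat any other error as transient and retry immediately.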
                continue
        return completion


class OpenAIChatAPIEvalWrapper(OpenAIEvalInterface):

    def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None:
        super().__init__(model_cfg, tokenizer)
        try:
            import openai
        except ImportError as e:
            raise MissingConditionalImportError(
                extra_deps_group='openai',
                conda_package='openai',
                conda_channel='conda-forge') from e

        self.generate_completion = lambda prompt, num_tokens: openai.ChatCompletion.create(
            model=self.model_name,
            messages=[{
                'role': 'user',
                'content': prompt
            }],
            max_tokens=num_tokens,
            temperature=0.0)

    def retokenize(self, tokens: List[int], cont_idxs: List[int]):
        """Chat API will never respond with a word-initial space.

        If the continuation tokens begin with a word-initial space, we need to
        re-tokenize with the space removed.
        """
        original_len = len(tokens)
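        # Decode the continuation span, strip surrounding whitespace, and re-tokenize it.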
        retokenized_continuation = self.tokenizer(
            self.tokenizer.decode(tokens[cont_idxs[0]:cont_idxs[-1] +
                                         1]).strip())['input_ids']

        # replace the original continuation with the retokenized continuation + padding
        padding = [tokens[-1]] * (
            len(tokens) - len(tokens[:cont_idxs[0]] + retokenized_continuation))
        tokens = tokens[:cont_idxs[0]] + retokenized_continuation + padding

        if len(tokens) > original_len:
            # this only happens if we were already at max seq len and the continuation got LARGER
            tokens = tokens[-original_len:]
            cont_idxs = list(
                range(original_len - len(retokenized_continuation),
                      original_len))
        else:
            cont_idxs = list(
                range(cont_idxs[0],
                      cont_idxs[0] + len(retokenized_continuation)))
        return torch.tensor(tokens), torch.tensor(cont_idxs)

    def rebatch(self, batch: Batch):
        """Chat API tokenization has different behavior than GPT3.

        Model responses will never begin with spaces even if the continuation is
        expected to, so we need to retokenize the input to account for that.
        """
        new_batch: Dict[str, Union[List[torch.Tensor], torch.Tensor]] = {
            'input_ids': [],
            'continuation_indices': [],
            'labels': []
        }
        for tokens, cont_idxs in zip(batch['input_ids'],
                                     batch['continuation_indices']):
            tokens, cont_idxs = self.retokenize(tokens.tolist(),
                                                cont_idxs.tolist())

            assert isinstance(new_batch['input_ids'], list)
            new_batch['input_ids'].append(tokens)
            assert isinstance(new_batch['labels'], list)
            new_batch['labels'].append(tokens)
            assert isinstance(new_batch['continuation_indices'], list)
            new_batch['continuation_indices'].append(cont_idxs)

        new_batch.update({
            k: torch.stack(new_batch[k])  # pyright: ignore
            for k in ['input_ids', 'labels']
        })

        new_batch.update({k: v for k, v in batch.items() if k not in new_batch})

        return new_batch

    def eval_forward(self, batch: Batch, outputs: Optional[Any] = None):
        # Override the base class because the Chat API always strips spacing from
        # model outputs, resulting in different tokens than the continuation expects.
        # Work around this by retokenizing the batch to remove spacing from the
        # continuation as well, and by decoding the whole continuation at once.
        output_logits_batch = []
        batch = self.rebatch(batch)
        for tokens, cont_idxs in zip(batch['input_ids'],
                                     batch['continuation_indices']):
            seqlen = tokens.shape[0]
            tokens = tokens.tolist()
            cont_idxs = cont_idxs.tolist()
            expected_cont_tokens = tokens[cont_idxs[0]:cont_idxs[-1] + 1]
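            # One-hot placeholder "logits" for the prompt positions; the continuation
            # positions are filled in from the API response below.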
            output_logits = torch.nn.functional.one_hot(
                torch.tensor(tokens[1:cont_idxs[0]]),
                num_classes=self.tokenizer.vocab_size)

            prompt = self.tokenizer.decode(tokens[:cont_idxs[0]])
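            # Request exactly the number of tokens the continuation is expected to contain.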
            next_logit_tensor = self.get_next_token_logit_tensor(
                prompt, num_tokens=len(expected_cont_tokens))

            if next_logit_tensor is not None:
                output_logits = torch.cat([output_logits, next_logit_tensor])
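            # Pad the rest of the row with one-hot pad-token logits so every row has length seqlen.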
            padding = torch.nn.functional.one_hot(
                torch.full((seqlen - output_logits.shape[0],),
                           self.tokenizer.pad_token_id),
                num_classes=self.tokenizer.vocab_size)
            output_logits = torch.cat([output_logits, padding])
            output_logits_batch.append(output_logits)

        return torch.stack(output_logits_batch).to(batch['input_ids'].device)

    def process_result(self, completion: Optional[dict]):
        assert isinstance(completion, dict)
        if len(completion['choices']) > 0:
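            # Build a pseudo-logit tensor for each returned token, assigning the
            # sampled token a log-probability of 0.0 (i.e. probability 1).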
            tensors = []
            for t in self.tokenizer(completion['choices'][0]['message']
                                    ['content'])['input_ids']:
                tensors.append(
                    self.tokenizer.construct_logit_tensor(
                        {self.tokenizer.decode([t]): 0.0}))

            if len(tensors) == 0:
                return None
            return torch.stack(tensors)
        else:
            # the model sometimes stops early even though we are still requesting tokens!
            # not sure if there's a fix
            return None


class OpenAICausalLMEvalWrapper(OpenAIEvalInterface):

    def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None:
        super().__init__(model_cfg, tokenizer)
        try:
            import openai
        except ImportError as e:
            raise MissingConditionalImportError(
                extra_deps_group='openai',
                conda_package='openai',
                conda_channel='conda-forge') from e

        self.generate_completion = lambda prompt, num_tokens: openai.Completion.create(
            engine=self.model_name,
            prompt=prompt,
            max_tokens=1,
            logprobs=5,
            temperature=0.0)

    def process_result(self, completion: Optional[dict]):
        if completion is None:
            raise ValueError("Couldn't generate model output")

        assert isinstance(completion, dict)
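        # Convert the top-5 logprobs of the first generated token into a logit tensor.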
        if len(completion['choices'][0]['logprobs']['top_logprobs']) > 0:
            tensor = self.tokenizer.construct_logit_tensor(
                dict(completion['choices'][0]['logprobs']['top_logprobs'][0]))
            return tensor
        else:
            # the model sometimes stops early even though we are still requesting tokens!
            # not sure if there's a fix
            return None