File size: 2,504 Bytes
0c3992e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import re
import torch
import openai
from functools import partial
import time
import multiprocessing


def get_openai_embedding(text,
                         model="text-embedding-ada-002",
                         max_retry=1,
                         sleep_time=0):
    """Embed a single string with the OpenAI embeddings endpoint.

    If the API rejects the request because the text exceeds the model's
    context length, the text is truncated (on single-space boundaries) to
    progressively smaller prefixes until one is accepted.

    Args:
        text: Non-empty string to embed.
        model: Name of the embedding model passed to the API.
        max_retry: Maximum number of attempts at embedding the full text.
        sleep_time: Seconds to wait after a rate-limit or timeout error
            before the next attempt.

    Returns:
        A ``torch.FloatTensor`` of shape ``(1, d)`` holding the embedding.

    Raises:
        AssertionError: If ``text`` is not a ``str`` or is empty.
        RuntimeError: If no embedding could be obtained (retries exhausted,
            a bad request that truncation cannot fix, or ``max_retry < 1``).
            The last API error, if any, is attached as ``__cause__``.
        Any other exception raised by the client propagates unchanged.
    """
    assert isinstance(text, str), f'text must be str, but got {type(text)}'
    assert len(text) > 0, f'text to be embedded should be non-empty'
    client = openai.OpenAI()
    last_error = None
    for attempt in range(max_retry):
        try:
            emb = client.embeddings.create(input=[text], model=model)
            return torch.FloatTensor(emb.data[0].embedding).view(1, -1)
        except openai.BadRequestError as e:
            print(f'{e}')
            last_error = e
            match = re.search(
                r'maximum context length is (\d+) tokens, however you requested (\d+) tokens',
                str(e),
            )
            if match is None:
                # Not a length problem: resending the same request cannot
                # succeed, so stop retrying and report the failure.
                break
            max_length = int(match.group(1))
            cur_length = int(match.group(2))
            ratio = float(max_length) / cur_length
            words = text.split(' ')
            ori_length = len(words)
            # Try 90%, 80%, ..., 10% of the estimated maximum word count.
            for reduce_rate in range(9, 0, -1):
                length = int(ratio * ori_length * (reduce_rate * 0.1))
                if length < 1:
                    # Smaller rates only give shorter (empty) prefixes,
                    # which are not valid inputs.
                    break
                shorten_text = ' '.join(words[:length])
                try:
                    emb = client.embeddings.create(input=[shorten_text], model=model)
                    print(f'length={length} works! reduce_rate={0.1 * reduce_rate}.')
                    return torch.FloatTensor(emb.data[0].embedding).view(1, -1)
                except openai.OpenAIError as shorten_error:
                    # Only API failures are tolerated here; interrupts and
                    # programming errors must not be swallowed.
                    last_error = shorten_error
                    continue
        except (openai.RateLimitError, openai.APITimeoutError) as e:
            last_error = e
            if attempt + 1 < max_retry:
                print(f'{e}, sleep for {sleep_time} s')
                time.sleep(sleep_time)
            else:
                print(f'{e}')
    # Fail loudly instead of implicitly returning None, which callers would
    # only notice later as an unrelated error.
    raise RuntimeError(
        f'failed to get embedding from model {model!r} '
        f'after {max_retry} attempt(s): {last_error}'
    ) from last_error


def get_openai_embeddings(texts,
                          n_max_nodes=5,
                          model="text-embedding-ada-002"
                          ):
    """
    Embed every string in `texts` using a pool of worker processes.

    Each text is passed to `get_openai_embedding` together with `model`;
    the pool uses at most `n_max_nodes` processes (fewer when there are
    fewer texts). The per-text results are concatenated along dim 0 in the
    same order as `texts`.
    """
    assert isinstance(texts, list), f'texts must be list, but got {type(texts)}'
    # A length is never negative, so "no zero length" means "all non-empty".
    assert 0 not in [len(item) for item in texts], f'every string in the `texts` list to be embedded should be non-empty'

    worker_count = min(len(texts), n_max_nodes)
    jobs = [(item, model) for item in texts]
    with multiprocessing.Pool(processes=worker_count) as workers:
        rows = workers.starmap(get_openai_embedding, jobs)

    return torch.cat(rows, dim=0)