import re
import torch
import openai
from functools import partial
import time
import multiprocessing
def get_openai_embedding(text,
                         model="text-embedding-ada-002",
                         max_retry=1,
                         sleep_time=0):
    """Embed a single string with the OpenAI API and return a (1, dim) float tensor."""
    assert isinstance(text, str), f'text must be str, but got {type(text)}'
    assert len(text) > 0, 'text to be embedded should be non-empty'
    client = openai.OpenAI()
    for _ in range(max_retry):
        try:
            emb = client.embeddings.create(input=[text], model=model)
            return torch.FloatTensor(emb.data[0].embedding).view(1, -1)
        except openai.BadRequestError as e:
            print(f'{e}')
            e = str(e)
            ori_length = len(text.split(' '))
            # The error message reports the model's context limit and the requested
            # token count; their ratio estimates how much the input must shrink.
            match = re.search(r'maximum context length is (\d+) tokens, however you requested (\d+) tokens', e)
            if match is not None:
                max_length = int(match.group(1))
                cur_length = int(match.group(2))
                ratio = float(max_length) / cur_length
                # Retry with progressively shorter prefixes (90%, 80%, ..., 10% of the
                # estimated maximum word count) until one fits into the context window.
                for reduce_rate in range(9, 0, -1):
                    shorten_text = text.split(' ')
                    length = int(ratio * ori_length * (reduce_rate * 0.1))
                    shorten_text = ' '.join(shorten_text[:length])
                    try:
                        emb = client.embeddings.create(input=[shorten_text], model=model)
                        print(f'length={length} works! reduce_rate={0.1 * reduce_rate}.')
                        return torch.FloatTensor(emb.data[0].embedding).view(1, -1)
                    except Exception:
                        continue
        except (openai.RateLimitError, openai.APITimeoutError) as e:
            print(f'{e}, sleep for {sleep_time} s')
            time.sleep(sleep_time)
def get_openai_embeddings(texts,
                          n_max_nodes=5,
                          model="text-embedding-ada-002"):
    """
    Get embeddings for a list of texts.

    Requests are issued in parallel across up to `n_max_nodes` worker processes,
    and the per-text results are concatenated into a single (len(texts), dim) tensor.
    """
    assert isinstance(texts, list), f'texts must be list, but got {type(texts)}'
    assert all([len(s) > 0 for s in texts]), 'every string in the `texts` list to be embedded should be non-empty'
    processes = min(len(texts), n_max_nodes)
    ada_encoder = partial(get_openai_embedding, model=model)
    with multiprocessing.Pool(processes=processes) as pool:
        results = pool.map(ada_encoder, texts)
    results = torch.cat(results, dim=0)
    return results
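

# A minimal usage sketch, not part of the original file: it assumes the
# OPENAI_API_KEY environment variable is set, and the example strings below
# are placeholders.
if __name__ == "__main__":
    sample_texts = [
        "Graph machine learning on large text corpora.",
        "Retrieval-augmented question answering.",
    ]
    embeddings = get_openai_embeddings(sample_texts, n_max_nodes=2)
    # Expected shape: (len(sample_texts), 1536) for text-embedding-ada-002.
    print(embeddings.shape)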