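"""Precompute OpenAI embeddings for item metadata.

Reads ../data/{dataset}/id2info.json, renders each item as a "key: value; ..."
string, embeds it with text-embedding-ada-002, and writes one JSON file per
item under ../save/embed/item/{dataset}, so an interrupted run can resume.
"""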
import json
import os
import random
import typing
from argparse import ArgumentParser

import openai
from loguru import logger
from tenacity import RetryCallState, Retrying, _utils, retry_if_not_exception_type
from tenacity.stop import stop_base
from tenacity.wait import wait_base


def my_before_sleep(retry_state):
    # Called by tenacity before each sleep between retries.
    logger.debug(
        f"Retrying: attempt {retry_state.attempt_number} ended with {retry_state.outcome}; "
        f"{retry_state.seconds_since_start:.2f}s spent in total"
    )


class my_wait_exponential(wait_base):
    """Exponential back-off that retries request timeouts immediately."""

    def __init__(
        self,
        multiplier: typing.Union[int, float] = 1,
        max: _utils.time_unit_type = _utils.MAX_WAIT,
        exp_base: typing.Union[int, float] = 2,
        min: _utils.time_unit_type = 0,
    ) -> None:
        self.multiplier = multiplier
        self.min = _utils.to_seconds(min)
        self.max = _utils.to_seconds(max)
        self.exp_base = exp_base

    def __call__(self, retry_state: RetryCallState) -> float:
        # `outcome` is a Future, so inspect its exception rather than comparing
        # the Future itself against the exception class.
        if retry_state.outcome.failed and isinstance(
            retry_state.outcome.exception(), openai.error.Timeout
        ):
            return 0

        try:
            exp = self.exp_base ** (retry_state.attempt_number - 1)
            result = self.multiplier * exp
        except OverflowError:
            return self.max
        return max(max(0, self.min), min(result, self.max))


class my_stop_after_attempt(stop_base):
    """Stop when the previous attempt number >= max_attempt_number."""

    def __init__(self, max_attempt_number: int) -> None:
        self.max_attempt_number = max_attempt_number

    def __call__(self, retry_state: RetryCallState) -> bool:
        # Timed-out attempts are not charged against the attempt budget.
        if retry_state.outcome.failed and isinstance(
            retry_state.outcome.exception(), openai.error.Timeout
        ):
            retry_state.attempt_number -= 1
        return retry_state.attempt_number >= self.max_attempt_number
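

# As wired up in annotate() below (multiplier=1, exp_base=2, min=1, max=60), the
# wait schedule after failures is 1, 2, 4, ... seconds clamped to [1, 60], while
# a request timeout waits 0s and does not count toward the 8-attempt budget.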


def annotate(item_text_list):
    # Start with a short timeout and double it (capped at 30s) after every
    # attempt, so slow batches eventually get enough time to complete.
    request_timeout = 6
    for attempt in Retrying(
        reraise=True,
        retry=retry_if_not_exception_type(
            (openai.error.InvalidRequestError, openai.error.AuthenticationError)
        ),
        wait=my_wait_exponential(min=1, max=60),
        stop=my_stop_after_attempt(8),
        before_sleep=my_before_sleep,
    ):
        with attempt:
            response = openai.Embedding.create(
                model="text-embedding-ada-002",
                input=item_text_list,
                request_timeout=request_timeout,
            )
        request_timeout = min(30, request_timeout * 2)

    return response
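

# Example (hypothetical input): annotate(["Title: Heat; Genre: crime"]) returns
# the raw OpenAI response; response["data"][i]["embedding"] is the embedding
# vector (1536 floats for text-embedding-ada-002) and response["data"][i]["index"]
# is the position of that embedding's text in the input batch.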


def get_exist_item_set():
    # One JSON file is written per item, so the filenames under save_dir double
    # as a record of which items already have embeddings.
    exist_item_set = set()
    for file in os.listdir(save_dir):
        item_id = os.path.splitext(file)[0]
        exist_item_set.add(item_id)
    return exist_item_set


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--api_key")
    parser.add_argument("--batch_size", default=1, type=int)
    parser.add_argument("--dataset", type=str, choices=["redial", "opendialkg"])
    args = parser.parse_args()

    openai.api_key = args.api_key
    batch_size = args.batch_size
    dataset = args.dataset

    save_dir = f"../save/embed/item/{dataset}"
    os.makedirs(save_dir, exist_ok=True)

    with open(f"../data/{dataset}/id2info.json", encoding="utf-8") as f:
        id2info = json.load(f)
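
    # id2info maps item id -> metadata dict. The code below assumes redial
    # entries carry "name", "genre", "star", "director", and "plot", and that
    # opendialkg entries carry "name" plus attributes such as "genre", "actor",
    # "director", and "writer".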

    # Per-dataset rendering of item metadata into "key: value; ..." strings;
    # attr_list selects the attributes used by the generic id2text pass below.
    if dataset == "redial":
        info_list = list(id2info.values())
        item_texts = []
        for info in info_list:
            item_text_list = [
                f"Title: {info['name']}",
                f"Genre: {', '.join(info['genre']).lower()}",
                f"Star: {', '.join(info['star'])}",
                f"Director: {', '.join(info['director'])}",
                f"Plot: {info['plot']}",
            ]
            item_text = "; ".join(item_text_list)
            item_texts.append(item_text)
        attr_list = ["genre", "star", "director"]

    if dataset == "opendialkg":
        item_texts = []
        for info_dict in id2info.values():
            item_attr_list = [f'Name: {info_dict["name"]}']
            for attr, value_list in info_dict.items():
                if attr != "name":  # "name" is already rendered above
                    item_attr_list.append(
                        f"{attr.capitalize()}: " + ", ".join(value_list)
                    )
            item_text = "; ".join(item_attr_list)
            item_texts.append(item_text)
        attr_list = ["genre", "actor", "director", "writer"]

    # Generic pass used for embedding: title plus the whitelisted attributes.
    id2text = {}
    for item_id, info_dict in id2info.items():
        attr_str_list = [f'Title: {info_dict["name"]}']
        for attr in attr_list:
            if attr not in info_dict:
                continue
            if isinstance(info_dict[attr], list):
                value_str = ", ".join(info_dict[attr])
            else:
                value_str = info_dict[attr]
            attr_str_list.append(f"{attr.capitalize()}: {value_str}")
        item_text = "; ".join(attr_str_list)
        id2text[item_id] = item_text
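
    # For a hypothetical redial entry {"name": "Heat", "genre": ["Crime"],
    # "director": ["Michael Mann"]}, id2text[item_id] becomes
    # "Title: Heat; Genre: Crime; Director: Michael Mann".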

    item_ids = set(id2info.keys()) - get_exist_item_set()
    while len(item_ids) > 0:
        logger.info(len(item_ids))

        # Draw a random batch of not-yet-embedded items; the sampling logic is
        # the same for both datasets.
        batch_item_ids = random.sample(
            tuple(item_ids), min(batch_size, len(item_ids))
        )
        batch_texts = [id2text[item_id] for item_id in batch_item_ids]

        batch_embeds = annotate(batch_texts)["data"]
        for embed in batch_embeds:
            # embed["index"] maps each embedding back to its batch position.
            item_id = batch_item_ids[embed["index"]]
            with open(f"{save_dir}/{item_id}.json", "w", encoding="utf-8") as f:
                json.dump(embed["embedding"], f, ensure_ascii=False)

        # Refresh the to-do set from disk so just-saved items are dropped.
        item_ids -= get_exist_item_set()