# src/model/CHATGPT.py
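"""ChatGPT-based conversational recommender model.

Wraps the OpenAI chat-completion and embedding endpoints with retry logic:
responses are generated with gpt-3.5-turbo, and items are recommended by
cosine similarity between the conversation embedding and pre-computed item
embeddings.
"""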
import json
import os
from copy import deepcopy
from typing import Any, Dict, List, Tuple, Union
import numpy as np
import openai
import tiktoken
from accelerate.utils import set_seed
from loguru import logger
from openai.types import CreateEmbeddingResponse
from sklearn.metrics.pairwise import cosine_similarity
from tenacity import Retrying, RetryCallState, _utils, retry_if_not_exception_type
from tenacity.stop import stop_base
from tenacity.wait import wait_base
from tqdm import tqdm
def my_before_sleep(retry_state):
    """Logs the outcome of an attempt before the retry wait."""
    logger.debug(
        f"Retrying: attempt {retry_state.attempt_number} ended with: "
        f"{retry_state.outcome}, spent {retry_state.seconds_since_start}s "
        "in total"
    )
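# Exponential backoff that waits nothing after a timed-out attempt, so
# requests that timed out are retried immediately.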
class my_wait_exponential(wait_base):
def __init__(
self,
multiplier: Union[int, float] = 1,
max: _utils.time_unit_type = _utils.MAX_WAIT, # noqa
exp_base: Union[int, float] = 2,
min: _utils.time_unit_type = 0, # noqa
) -> None:
self.multiplier = multiplier
self.min = _utils.to_seconds(min)
self.max = _utils.to_seconds(max)
self.exp_base = exp_base
def __call__(self, retry_state: "RetryCallState") -> float:
        # Do not wait at all if the previous attempt ended with a timeout.
        if retry_state.outcome is not None and isinstance(
            retry_state.outcome.exception(), openai.APITimeoutError
        ):
            return 0
try:
exp = self.exp_base ** (retry_state.attempt_number - 1)
result = self.multiplier * exp
except OverflowError:
return self.max
return max(max(0, self.min), min(result, self.max))
class my_stop_after_attempt(stop_base):
"""Stop when the previous attempt >= max_attempt."""
def __init__(self, max_attempt_number: int) -> None:
self.max_attempt_number = max_attempt_number
def __call__(self, retry_state: "RetryCallState") -> bool:
        # Attempts that ended with a timeout do not count towards the limit.
        if retry_state.outcome is not None and isinstance(
            retry_state.outcome.exception(), openai.APITimeoutError
        ):
            retry_state.attempt_number -= 1
return retry_state.attempt_number >= self.max_attempt_number
def annotate(conv_str: str) -> CreateEmbeddingResponse:
"""Creates embeddings for the given conversation string."""
request_timeout = 6.0
for attempt in Retrying(
reraise=True,
retry=retry_if_not_exception_type(
(
openai.BadRequestError,
openai.AuthenticationError,
)
),
wait=my_wait_exponential(min=1, max=60),
stop=(my_stop_after_attempt(8)),
before_sleep=my_before_sleep,
):
with attempt:
response = openai.embeddings.create(
model="text-embedding-ada-002",
input=conv_str,
timeout=request_timeout,
)
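        # Increase the timeout for the next attempt, capped at 30 seconds.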
request_timeout = min(30, request_timeout * 2)
return response
def annotate_chat(messages, logit_bias=None) -> str:
"""Generates a response given a conversation context.
Args:
messages: Conversation context (previous utterances).
logit_bias: Logit bias for the model.
Returns:
Generated response.
"""
request_timeout = 20.0
for attempt in Retrying(
reraise=True,
retry=retry_if_not_exception_type(
(
openai.BadRequestError,
openai.AuthenticationError,
)
),
wait=my_wait_exponential(min=1, max=60),
stop=(my_stop_after_attempt(8)),
before_sleep=my_before_sleep,
):
with attempt:
response = (
openai.chat.completions.create(
model="gpt-3.5-turbo",
messages=messages,
temperature=0.0,
logit_bias=logit_bias,
timeout=request_timeout,
)
.choices[0]
.message.content
)
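        # Increase the timeout for the next attempt, capped at 300 seconds.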
request_timeout = min(300, request_timeout * 2)
return response
class CHATGPT:
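    """Conversational recommender backed by OpenAI's ChatGPT.

    Uses chat completions for response generation and embedding similarity
    against pre-computed item embeddings for item recommendation.
    """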
def __init__(self, seed, debug, kg_dataset) -> None:
self.seed = seed
self.debug = debug
if self.seed is not None:
set_seed(self.seed)
self.kg_dataset = kg_dataset
self.kg_dataset_path = f"data/{self.kg_dataset}"
with open(
f"{self.kg_dataset_path}/entity2id.json", "r", encoding="utf-8"
) as f:
self.entity2id = json.load(f)
with open(
f"{self.kg_dataset_path}/id2info.json", "r", encoding="utf-8"
) as f:
self.id2info = json.load(f)
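        # Map raw item ids to knowledge-graph entity ids via the item name.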
self.id2entityid = {}
        for item_id, info in self.id2info.items():
            if info["name"] in self.entity2id:
                self.id2entityid[item_id] = self.entity2id[info["name"]]
self.item_embedding_path = f"data/embed_items/{self.kg_dataset}"
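        # Load the pre-computed item embeddings (one JSON file per item) into
        # an array aligned with id2item_id_arr for similarity search.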
item_emb_list = []
id2item_id = []
        for file in tqdm(os.listdir(self.item_embedding_path)):
item_id = os.path.splitext(file)[0]
if item_id in self.id2entityid:
id2item_id.append(item_id)
with open(
f"{self.item_embedding_path}/{file}", encoding="utf-8"
) as f:
embed = json.load(f)
item_emb_list.append(embed)
self.id2item_id_arr = np.asarray(id2item_id)
self.item_emb_arr = np.asarray(item_emb_list)
self.chat_recommender_instruction = (
"You are a recommender chatting with the user to provide "
"recommendation. You must follow the instructions below during "
"chat.\nIf you do not have enough information about user "
"preference, you should ask the user for his preference.\n"
"If you have enough information about user preference, you can "
"give recommendation. The recommendation list must contain 10 "
"items that are consistent with user preference. The "
"recommendation list can contain items that the dialog mentioned "
"before. The format of the recommendation list is: no. title. "
"Don't mention anything other than the title of items in your "
"recommendation list."
)
def get_rec(self, conv_dict):
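        """Ranks candidate items for the current conversation.

        The last two turns of the dialogue are embedded and items are scored
        by cosine similarity against the pre-computed item embeddings.

        Args:
            conv_dict: Conversation context.

        Returns:
            Top-50 ranked entity ids (as a single-element batch) and the
            entity ids of the items listed in conv_dict["rec"].
        """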
rec_labels = [
self.entity2id[rec]
for rec in conv_dict["rec"]
if rec in self.entity2id
]
context = conv_dict["context"]
context_list = [] # for model
for i, text in enumerate(context):
if len(text) == 0:
continue
if i % 2 == 0:
role_str = "user"
else:
role_str = "assistant"
context_list.append({"role": role_str, "content": text})
conv_str = ""
for context in context_list[-2:]:
conv_str += f"{context['role']}: {context['content']} "
conv_embed = annotate(conv_str).data[0].embedding
conv_embed = np.asarray(conv_embed).reshape(1, -1)
sim_mat = cosine_similarity(conv_embed, self.item_emb_arr)
        rank_arr = np.argsort(sim_mat, axis=-1)
rank_arr = np.flip(rank_arr, axis=-1)[:, :50]
item_rank_arr = self.id2item_id_arr[rank_arr].tolist()
item_rank_arr = [
[self.id2entityid[item_id] for item_id in item_rank_arr[0]]
]
return item_rank_arr, rec_labels
def get_conv(self, conv_dict):
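        """Generates a response with the recommender system prompt.

        Args:
            conv_dict: Conversation context.

        Returns:
            Tuple of generation inputs (always None for this model) and the
            generated response.
        """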
context = conv_dict["context"]
context_list = [] # for model
context_list.append(
{"role": "system", "content": self.chat_recommender_instruction}
)
for i, text in enumerate(context):
if len(text) == 0:
continue
if i % 2 == 0:
role_str = "user"
else:
role_str = "assistant"
context_list.append({"role": role_str, "content": text})
gen_inputs = None
gen_str = annotate_chat(context_list)
return gen_inputs, gen_str
def get_choice(self, gen_inputs, options, state, conv_dict):
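        """Asks the model to pick one of the remaining option letters.

        Options whose state is negative are excluded, and a logit bias on the
        first token of each remaining option steers the model towards one of
        them.

        Args:
            gen_inputs: Unused generation inputs.
            options: Option letters.
            state: Per-option state; options with a negative state are
                skipped.
            conv_dict: Conversation context.

        Returns:
            First character of the model's answer.
        """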
updated_options = []
for i, st in enumerate(state):
if st >= 0:
updated_options.append(options[i])
encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
logit_bias = {
encoding.encode(option)[0]: 10 for option in updated_options
}
context = conv_dict["context"]
context_list = [] # for model
for i, text in enumerate(context[:-1]):
if len(text) == 0:
continue
if i % 2 == 0:
role_str = "user"
else:
role_str = "assistant"
context_list.append({"role": role_str, "content": text})
context_list.append({"role": "user", "content": context[-1]})
response_op = annotate_chat(context_list, logit_bias=logit_bias)
return response_op[0]
def get_response(
self,
conv_dict: Dict[str, Any],
id2entity: Dict[int, str],
options: Tuple[str, Dict[str, str]],
state: List[float],
) -> Tuple[str, List[float]]:
"""Generates a response given a conversation context.
Args:
conv_dict: Conversation context.
id2entity: Mapping from entity id to entity name.
options: Prompt with options and dictionary of options.
state: State of the option choices.
Returns:
Generated response and updated state.
"""
initial_conv_dict = deepcopy(conv_dict)
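        # Append the prompt listing the available options to the context.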
conv_dict["context"].append(options[0])
generated_inputs, generated_response = self.get_conv(conv_dict)
options_letter = list(options[1].keys())
# Get the choice between recommend and generate
choice = self.get_choice(
generated_inputs, options_letter, state, conv_dict
)
if choice == options_letter[-1]:
# Generate a recommendation
recommended_items, _ = self.get_rec(conv_dict)
recommended_items_str = ""
for i, item_id in enumerate(recommended_items[0][:3]):
recommended_items_str += f"{i+1}: {id2entity[item_id]} \n"
response = (
"I would recommend the following items: \n"
f"{recommended_items_str}"
)
else:
            # Original: generate a response to ask for preferences, falling
            # back to the generated response.
            # response = (
            #     options[1].get(choice, {}).get("template", generated_response)
            # )
            # Regenerate the response with the original context; otherwise
            # the generated response would just be the option's letter.
_, generated_response = self.get_conv(initial_conv_dict)
response = generated_response
# Update the state. Hack: penalize the choice to reduce the
# likelihood of selecting the same choice again
state[options_letter.index(choice)] = -1e5
return response, state
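# Illustrative usage sketch (assumes OPENAI_API_KEY is set and that the
# entity2id/id2info files and item embeddings exist under data/ for the
# chosen dataset; the dataset name below is only an example):
#
#     model = CHATGPT(seed=42, debug=False, kg_dataset="redial")
#     conv_dict = {"context": ["Hi! I am looking for a movie to watch."],
#                  "rec": []}
#     _, response = model.get_conv(conv_dict)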