|
import json
|
|
import random
|
|
from typing import Dict, List, Optional, Tuple, Union
|
|
|
|
import torch
|
|
from rapidfuzz import fuzz, process
|
|
from torch import nn
|
|
from torch.nn import functional as F
|
|
|
|
# Extra special tokens to register with a tokenizer; "<|pad|>" serves as the
# padding token (presumably passed to tokenizer.add_special_tokens — confirm caller).
special_tokens_dict = {"pad_token": "<|pad|>"}
|
|
|
|
|
|
def load_jsonl_data(file):
    """Read a JSON-Lines file and return one parsed object per line.

    Args:
        file: Path to a UTF-8 encoded ``.jsonl`` file.

    Returns:
        A list with the decoded JSON value of each line, in file order.
    """
    with open(file, encoding="utf-8") as f:
        return [json.loads(line) for line in f]
|
|
|
|
|
|
def simple_collate(batch):
    """Identity collate function for a DataLoader: hand the batch back untouched.

    Useful when samples are heterogeneous dicts that should not be stacked
    into tensors by the default collator.
    """
    return batch
|
|
|
|
|
|
def sample_data(data_list, shot=1, debug=False, number_for_debug=320):
    """Subsample a dataset according to ``shot``.

    Args:
        data_list: Full list of examples.
        shot: If < 1, keep a random fraction of that size; if > 1, keep the
            first ``int(shot)`` examples; if exactly 1, keep everything.
        debug: If True, first truncate to ``number_for_debug`` examples.
        number_for_debug: Cap applied when ``debug`` is set.

    Returns:
        The (possibly) reduced list of examples.
    """
    if debug:
        data_list = data_list[:number_for_debug]

    if shot < 1:
        # Fractional shot: draw a random subset of that proportion.
        chosen = random.sample(range(len(data_list)), int(len(data_list) * shot))
        return [data_list[i] for i in chosen]
    if shot > 1:
        # Integer shot: take the first `shot` examples (raises IndexError if
        # shot exceeds the dataset size, matching index-based selection).
        return [data_list[i] for i in range(int(shot))]
    return data_list
|
|
|
|
|
|
def padded_tensor(
    items: List[Union[List[int], torch.LongTensor]],
    pad_id: int = 0,
    pad_tail: bool = True,
    device: torch.device = torch.device("cpu"),
    debug: bool = False,
    max_length: Optional[int] = None,
) -> torch.Tensor:
    """Pack variable-length id sequences into one rectangular LongTensor.

    Args:
        items: Sequences (lists of ints or 1-D LongTensors) to pad.
        pad_id: Fill value for padding positions.
        pad_tail: If True pad on the right, otherwise pad on the left.
        device: Device for the output tensor.
        debug: When True together with ``max_length``, force the width up to
            ``max_length`` (sequences are never truncated).
        max_length: Minimum width used only in debug mode.

    Returns:
        LongTensor of shape (len(items), width) where width is the longest
        sequence length (at least 1).
    """
    lengths = [len(seq) for seq in items]
    width = max(max(lengths), 1)  # keep at least one column even if all empty
    if debug and max_length is not None:
        width = max(width, max_length)

    padded = torch.full(
        (len(items), width), fill_value=pad_id, dtype=torch.long, device=device
    )

    for row, (seq, seq_len) in enumerate(zip(items, lengths)):
        if not seq_len:
            continue  # row stays all pad_id
        if not isinstance(seq, torch.Tensor):
            seq = torch.as_tensor(seq, dtype=torch.long, device=device)
        if pad_tail:
            padded[row, :seq_len] = seq
        else:
            padded[row, width - seq_len:] = seq

    return padded
|
|
|
|
|
|
class SelfAttention(nn.Module):
    """Additive (tanh) self-attention pooling: collapses a sequence of hidden
    states into a single vector using learned per-position weights."""

    def __init__(self, hidden_size):
        super(SelfAttention, self).__init__()
        # Scores each position with a small MLP: hs -> hs -> 1.
        self.attn = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, x, mask=None):
        """
        Args:
            x (bs, seq_len, hs)
            mask (bs, seq_len): False for masked token.

        Returns:
            (bs, hs)
        """
        attn = self.attn(x)  # (bs, seq_len, 1)
        if mask is not None:
            # Large negative bias drives masked positions toward zero weight.
            attn += (~mask).unsqueeze(-1) * -1e4
        # BUG FIX: normalize over the sequence dimension (dim=1). The original
        # used dim=-1, which is the singleton score dimension — softmax over a
        # size-1 axis yields all-ones weights, so the output degenerated to an
        # unweighted sum and the mask had no effect.
        attn = F.softmax(attn, dim=1)
        x = attn.transpose(1, 2) @ x  # (bs, 1, hs)
        x = x.squeeze(1)
        return x
|
|
|
|
|
|
def shift_tokens_right(
    input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int
):
    """
    Shift input ids one token to the right.

    Builds decoder inputs from labels: prepends ``decoder_start_token_id``,
    drops the last token, and replaces any -100 placeholders with the pad id.

    Raises:
        ValueError: If ``pad_token_id`` is None.
    """
    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")

    shifted = input_ids.new_zeros(input_ids.shape)
    shifted[:, 1:] = input_ids[:, :-1].detach().clone()
    shifted[:, 0] = decoder_start_token_id
    # -100 is the ignore-index convention for labels; make those real pads.
    shifted.masked_fill_(shifted == -100, pad_token_id)
    return shifted
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_entity(text, entity_list):
    """Fuzzy-match ``text`` against candidate entities.

    Uses rapidfuzz WRatio over the top 20 candidates and keeps those scoring
    at least 90.

    Args:
        text: Query string.
        entity_list: Candidate entity names.

    Returns:
        List of matched entity names (score >= 90), best matches first.
    """
    matches = process.extract(text, entity_list, scorer=fuzz.WRatio, limit=20)
    return [name for name, score, *_ in matches if score >= 90]
|
|
|
|
|
|
def get_options(dataset: str) -> Tuple[str, Dict[str, Dict[str, str]]]:
    """Returns the possible options for a given dataset.

    Args:
        dataset: The dataset to get options for.

    Raises:
        ValueError: If the dataset is not supported.

    Returns:
        A tuple containing the prompt and a dictionary of options, where each
        option maps a choice letter to a dict with "attribute" and "template"
        keys.
    """
    # FIX: the return annotation previously said Dict[str, str], but every
    # option value is itself a dict — corrected to Dict[str, Dict[str, str]].
    if "redial" in dataset:
        instructions = (
            "To recommend me items that I will accept, you can choose one of "
            "the following options.\nA: ask my preference for genre\nB: ask my "
            "preference for actor\nC: ask my preference for director\nD: I can "
            "directly give recommendations\nPlease enter the option character. "
            "Please only response a character."
        )
        options = {
            "A": {"attribute": "genre", "template": "What genre do you like?"},
            "B": {"attribute": "actor", "template": "Which star do you like?"},
            "C": {
                "attribute": "director",
                "template": "Which director do you like?",
            },
            "D": {"attribute": "recommend", "template": ""},
        }
        return instructions, options
    elif "opendialkg" in dataset:
        instructions = (
            "To recommend me items that I will accept, you can choose one of "
            "the following options.\nA: ask my preference for genre\nB: ask my "
            "preference for actor\nC: ask my preference for director\nD: ask "
            "my preference for writer\nE: I can directly give recommendations"
            "\nPlease enter the option character. Please only response a "
            "character."
        )
        options = {
            "A": {"attribute": "genre", "template": "What genre do you like?"},
            "B": {"attribute": "actor", "template": "Which star do you like?"},
            "C": {
                "attribute": "director",
                "template": "Which director do you like?",
            },
            "D": {
                "attribute": "writer",
                "template": "Which writer do you like?",
            },
        }
        return instructions, options

    raise ValueError(f"Dataset {dataset} is not supported.")
|
|
|