CRSArena / src /model /utils.py
Nol00's picture
Update src/model/utils.py
00f8114 verified
raw
history blame
6.33 kB
import json
import random
from typing import Dict, List, Optional, Tuple, Union
import torch
from rapidfuzz import fuzz, process
from torch import nn
from torch.nn import functional as F
# Extra special tokens (only a pad token here); presumably handed to a
# tokenizer's ``add_special_tokens`` elsewhere — confirm at call sites.
special_tokens_dict = dict(pad_token="<|pad|>")
def load_jsonl_data(file):
    """Read a JSON Lines file into a list of parsed records.

    Args:
        file: Path (or anything ``open`` accepts) of a UTF-8 encoded file
            with one JSON value per line.

    Returns:
        A list with one parsed JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(file, encoding="utf-8") as f:
        # Skip blank / whitespace-only lines (e.g. an extra trailing newline);
        # json.loads("") would otherwise raise JSONDecodeError.
        return [json.loads(line) for line in f if line.strip()]
def simple_collate(batch):
    """Identity collate function: hand the batch back unchanged.

    No stacking or tensor conversion is performed; the exact object received
    is returned. Presumably used as a DataLoader ``collate_fn`` so samples
    stay in their raw form — confirm at call sites.
    """
    return batch
def sample_data(data_list, shot=1, debug=False, number_for_debug=320):
    """Subsample a dataset for few-shot or debug runs.

    Args:
        data_list: Sliceable, indexable sequence of examples.
        shot: Controls how many examples are kept:
            ``shot < 1`` keeps a random ``int(len(data_list) * shot)``
            examples (in random order); ``shot == 1`` keeps everything
            unchanged; ``shot > 1`` keeps the first ``int(shot)`` examples
            (or all of them if there are fewer).
        debug: If True, first truncate to ``number_for_debug`` examples.
        number_for_debug: Number of leading examples kept in debug mode.

    Returns:
        The selected examples. For ``shot != 1`` this is a new list; for
        ``shot == 1`` it is ``data_list`` (or its debug slice) itself.
    """
    if debug:
        data_list = data_list[:number_for_debug]
    if shot < 1:
        data_idx = random.sample(
            range(len(data_list)), int(len(data_list) * shot)
        )
        data_list = [data_list[idx] for idx in data_idx]
    elif shot > 1:
        # Clamp to the available examples instead of raising IndexError when
        # the requested shot count exceeds the dataset size.
        keep = min(int(shot), len(data_list))
        data_list = [data_list[idx] for idx in range(keep)]
    return data_list
def padded_tensor(
    items: List[Union[List[int], torch.LongTensor]],
    pad_id: int = 0,
    pad_tail: bool = True,
    device: torch.device = torch.device("cpu"),
    debug: bool = False,
    max_length: Optional[int] = None,
) -> torch.Tensor:
    """Pack variable-length id sequences into one padded LongTensor.

    Args:
        items: Sequences of token ids (lists or tensors), possibly empty.
        pad_id: Value used to fill padding positions.
        pad_tail: If True, items are left-aligned and padded on the right;
            otherwise they are right-aligned and padded on the left.
        device: Device on which the output tensor is allocated.
        debug: If True and ``max_length`` is given, pad out to at least
            ``max_length`` columns.
        max_length: Minimum width in debug mode; longer items are never
            truncated.

    Returns:
        A ``torch.long`` tensor of shape ``(len(items), t)``, where ``t`` is
        the longest item length (at least 1, or ``max_length`` in debug mode).
        An empty ``items`` list yields a ``(0, 1)`` tensor.
    """
    n = len(items)
    lens: List[int] = [len(item) for item in items]
    # `default=0` lets an empty batch through (max() of [] would raise);
    # the outer max keeps at least one column, as for all-empty items.
    t = max(max(lens, default=0), 1)
    if debug and max_length is not None:
        # NOTE(review): presumably a fixed width in debug runs to surface
        # worst-case memory use early — confirm intent.
        t = max(t, max_length)
    output = torch.full(
        (n, t), fill_value=pad_id, dtype=torch.long, device=device
    )
    for i, (item, length) in enumerate(zip(items, lens)):
        if length == 0:
            # Row stays entirely padding.
            continue
        if not isinstance(item, torch.Tensor):
            item = torch.as_tensor(item, dtype=torch.long, device=device)
        if pad_tail:
            output[i, :length] = item
        else:
            output[i, t - length:] = item
    return output
class SelfAttention(nn.Module):
    """Attention pooling of a token sequence into a single vector.

    A scorer (Linear -> Tanh -> Linear) assigns one score per token, the
    scores are turned into weights, and the output is the weighted sum of
    the token vectors.

    NOTE(review): the original code applies softmax over the last axis of
    the ``(bs, seq_len, 1)`` score tensor, i.e. over a singleton axis, so
    every weight is exactly 1. The output is then the plain sum over
    ``seq_len``, ``mask`` has no effect, and the scorer gets no gradient.
    That behaviour is kept as the default (``normalize_over_seq=False``) so
    models trained with it produce unchanged outputs. Pass
    ``normalize_over_seq=True`` for a real masked softmax over the sequence.
    """

    def __init__(self, hidden_size, normalize_over_seq=False):
        """
        Args:
            hidden_size: Size of the last dimension of the input.
            normalize_over_seq: If True, normalise scores over the sequence
                axis (proper attention). The default False keeps the legacy
                sum-pooling behaviour described in the class docstring.
        """
        super(SelfAttention, self).__init__()
        self.attn = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 1),
        )
        self.normalize_over_seq = normalize_over_seq

    def forward(self, x, mask=None):
        """
        Args:
            x (bs, seq_len, hs)
            mask (bs, seq_len): False (or 0) for masked token.
        Returns:
            (bs, hs)
        """
        attn = self.attn(x)  # (bs, seq_len, 1)
        if mask is not None:
            # `.bool()` so that 0/1 integer masks work too: `~` on an integer
            # tensor is a bitwise NOT (1 -> -2), not a logical negation.
            # In-place add keeps attn's dtype (e.g. fp16).
            attn += (~mask.bool()).unsqueeze(-1) * -1e4
        # dim=1 is the sequence axis; dim=-1 is the legacy singleton axis.
        softmax_dim = 1 if self.normalize_over_seq else -1
        attn = F.softmax(attn, dim=softmax_dim)
        x = attn.transpose(1, 2) @ x  # (bs, 1, hs)
        return x.squeeze(1)
def shift_tokens_right(
    input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int
):
    """Build decoder inputs by moving ``input_ids`` one position to the right.

    Column 0 of the result holds ``decoder_start_token_id``; every later
    column ``j`` holds ``input_ids[:, j - 1]`` (the final input column is
    dropped). Any ``-100`` entries (as used in labels) that land in the
    result are replaced by ``pad_token_id``. ``input_ids`` is not modified.

    Raises:
        ValueError: If ``pad_token_id`` is None.
    """
    result = input_ids.new_zeros(input_ids.shape)
    # Setting a slice copies the values, so `result` never aliases input_ids.
    result[:, 1:] = input_ids[:, :-1]
    result[:, 0] = decoder_start_token_id

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")

    is_ignore_label = result.eq(-100)
    result.masked_fill_(is_ignore_label, pad_token_id)
    return result
# Alternative get_entity using a local DBpedia Spotlight REST server (commented out; not active).
# def get_entity(text, SPOTLIGHT_CONFIDENCE):
# DBPEDIA_SPOTLIGHT_ADDR = " http://0.0.0.0:2222/rest/annotate"
# headers = {"accept": "application/json"}
# params = {"text": text, "confidence": SPOTLIGHT_CONFIDENCE}
# response = requests.get(DBPEDIA_SPOTLIGHT_ADDR, headers=headers, params=params)
# response = response.json()
# return (
# [f"<{x['@URI']}>" for x in response["Resources"]]
# if "Resources" in response
# else []
# )
# Active get_entity: fuzzy-match the text against a known entity list with rapidfuzz.
def get_entity(text, entity_list, limit=20, score_cutoff=90):
    """Return the entities from ``entity_list`` that fuzzily match ``text``.

    Candidates are ranked with rapidfuzz's ``process.extract`` using the
    ``fuzz.WRatio`` scorer (scores range 0-100). The ``limit`` best
    candidates are kept, and of those only the ones scoring at least
    ``score_cutoff`` are returned.

    Args:
        text: Text to match against the entity names.
        entity_list: Collection of candidate entity strings.
        limit: Maximum number of candidates considered (default 20).
        score_cutoff: Minimum WRatio score to accept a match (default 90).

    Returns:
        The matching entities, best score first.
    """
    extractions = process.extract(
        text, entity_list, scorer=fuzz.WRatio, limit=limit
    )
    # Each result is (choice, score, key); only the choice is returned.
    return [
        extraction[0]
        for extraction in extractions
        if extraction[1] >= score_cutoff
    ]
def get_options(dataset: str) -> Tuple[str, Dict[str, Dict[str, str]]]:
    """Returns the action-selection prompt and option table for a dataset.

    The prompt offers one lettered option per askable attribute (A, B, ...),
    followed by a final option to recommend directly. Prompt text and option
    table are generated from the same attribute list, so they always agree.

    Args:
        dataset: The dataset to get options for. It is matched by substring,
            checking ``"redial"`` first and then ``"opendialkg"``.

    Raises:
        ValueError: If the dataset is not supported.

    Returns:
        A tuple ``(instructions, options)``. ``options`` maps each option
        letter to ``{"attribute": ..., "template": ...}``, where ``template``
        is the question to ask (``""`` for the ``"recommend"`` option). A new
        ``options`` dict is built on every call.
    """
    if "redial" in dataset:
        attributes = ["genre", "actor", "director"]
    elif "opendialkg" in dataset:
        attributes = ["genre", "actor", "director", "writer"]
    else:
        raise ValueError(f"Dataset {dataset} is not supported.")

    # Question asked when the corresponding attribute option is chosen.
    templates = {
        "genre": "What genre do you like?",
        "actor": "Which star do you like?",
        "director": "Which director do you like?",
        "writer": "Which writer do you like?",
    }

    lines = [
        "To recommend me items that I will accept, you can choose one of "
        "the following options."
    ]
    options: Dict[str, Dict[str, str]] = {}
    for index, attribute in enumerate(attributes):
        letter = chr(ord("A") + index)
        lines.append(f"{letter}: ask my preference for {attribute}")
        options[letter] = {
            "attribute": attribute,
            "template": templates[attribute],
        }

    # The option right after the last attribute means "recommend now".
    letter = chr(ord("A") + len(attributes))
    lines.append(f"{letter}: I can directly give recommendations")
    options[letter] = {"attribute": "recommend", "template": ""}

    # NOTE: wording (e.g. "only response") is left as-is so the prompt text
    # stays byte-identical to what callers have been sending.
    lines.append(
        "Please enter the option character. Please only response a "
        "character."
    )
    return "\n".join(lines), options