File size: 6,334 Bytes
00f8114 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 |
import json
import random
from typing import Dict, List, Optional, Tuple, Union
import torch
from rapidfuzz import fuzz, process
from torch import nn
from torch.nn import functional as F
# Special-token registration payload; presumably passed to a Hugging Face
# tokenizer's add_special_tokens() to register a dedicated pad token —
# confirm against the caller.
special_tokens_dict = {"pad_token": "<|pad|>"}
def load_jsonl_data(file):
    """Read a JSON-Lines file and return its records as a list.

    Args:
        file: Path to a UTF-8 encoded ``.jsonl`` file, one JSON value per line.

    Returns:
        List of deserialized objects, in file order.
    """
    with open(file, encoding="utf-8") as f:
        return [json.loads(line) for line in f]
def simple_collate(batch):
    """Identity collate function: hand the sampled batch through untouched."""
    return batch
def sample_data(data_list, shot=1, debug=False, number_for_debug=320):
    """Subsample a dataset for few-shot or debug runs.

    Args:
        data_list: Full list of examples.
        shot: Fraction (< 1) to draw at random, absolute count (> 1) to take
            from the front, or exactly 1 to keep everything.
        debug: If True, first truncate to ``number_for_debug`` examples.
        number_for_debug: Truncation size applied when ``debug`` is set.

    Returns:
        The (possibly) reduced list of examples.
    """
    pool = data_list[:number_for_debug] if debug else data_list
    if shot < 1:
        # Random subset covering the requested fraction of the pool.
        chosen = random.sample(range(len(pool)), int(len(pool) * shot))
        return [pool[i] for i in chosen]
    if shot > 1:
        # Deterministic prefix of int(shot) examples.
        return [pool[i] for i in range(int(shot))]
    return pool
def padded_tensor(
    items: List[Union[List[int], torch.LongTensor]],
    pad_id: int = 0,
    pad_tail: bool = True,
    device: torch.device = torch.device("cpu"),
    debug: bool = False,
    max_length: Optional[int] = None,
) -> torch.Tensor:
    """Stack variable-length sequences into one padded LongTensor.

    Args:
        items: Sequences (int lists or 1-D long tensors) to pad.
        pad_id: Fill value for padded positions.
        pad_tail: Pad on the right when True, on the left otherwise.
        device: Device the output tensor is allocated on.
        debug: When True (and ``max_length`` given), widen the time
            dimension to at least ``max_length``.
        max_length: Minimum width enforced in debug mode.

    Returns:
        Tensor of shape ``(len(items), T)``, where ``T`` is the longest
        sequence length (at least 1).
    """
    lengths = [len(seq) for seq in items]
    # Width is never zero, so an all-empty batch still yields one pad column.
    width = max(max(lengths), 1)
    if debug and max_length is not None:
        width = max(width, max_length)
    out = torch.full(
        (len(items), width), fill_value=pad_id, dtype=torch.long, device=device
    )
    for row, (seq, seq_len) in enumerate(zip(items, lengths)):
        if seq_len == 0:
            # Nothing to copy; the row stays all pad_id.
            continue
        if not isinstance(seq, torch.Tensor):
            seq = torch.as_tensor(seq, dtype=torch.long, device=device)
        if pad_tail:
            out[row, :seq_len] = seq
        else:
            out[row, width - seq_len:] = seq
    return out
class SelfAttention(nn.Module):
    """Additive self-attention pooling.

    Scores each timestep with a small MLP and returns the attention-weighted
    sum of the hidden states over the sequence dimension.
    """

    def __init__(self, hidden_size):
        super(SelfAttention, self).__init__()
        # Two-layer scorer mapping each hidden state to a scalar logit.
        self.attn = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, x, mask=None):
        """
        Args:
            x (bs, seq_len, hs)
            mask (bs, seq_len): False for masked token.
        Returns:
            (bs, hs)
        """
        attn = self.attn(x)  # (bs, seq_len, 1)
        if mask is not None:
            # Push masked positions to ~zero weight after the softmax.
            attn += (~mask).unsqueeze(-1) * -1e4
        # BUG FIX: normalize over the sequence dimension (dim=1), not the
        # singleton score dimension (dim=-1). With dim=-1 every weight was
        # exactly 1.0, degenerating to an unnormalized sum and silently
        # ignoring the mask.
        attn = F.softmax(attn, dim=1)
        x = attn.transpose(1, 2) @ x  # (bs, 1, hs)
        x = x.squeeze(1)
        return x
def shift_tokens_right(
    input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int
):
    """
    Shift input ids one token to the right.

    Builds decoder inputs from labels: position 0 becomes the decoder start
    token, the final label is dropped, and -100 (the loss ignore index) is
    replaced by the pad token id.
    """
    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")
    shifted = input_ids.new_zeros(input_ids.shape)
    shifted[:, 0] = decoder_start_token_id
    shifted[:, 1:] = input_ids[:, :-1].detach().clone()
    # Labels mark ignored positions with -100; decoder inputs need real ids.
    shifted.masked_fill_(shifted == -100, pad_token_id)
    return shifted
# Alternative entity linker using DBpedia Spotlight, kept for reference;
# requires a running Spotlight server and the `requests` package.
# def get_entity(text, SPOTLIGHT_CONFIDENCE):
# DBPEDIA_SPOTLIGHT_ADDR = " http://0.0.0.0:2222/rest/annotate"
# headers = {"accept": "application/json"}
# params = {"text": text, "confidence": SPOTLIGHT_CONFIDENCE}
# response = requests.get(DBPEDIA_SPOTLIGHT_ADDR, headers=headers, params=params)
# response = response.json()
# return (
# [f"<{x['@URI']}>" for x in response["Resources"]]
# if "Resources" in response
# else []
# )
# rapidfuzz get entity
def get_entity(text, entity_list):
    """Fuzzy-match ``text`` against a list of entity names with RapidFuzz.

    Returns up to 20 entities whose WRatio similarity to ``text`` is at
    least 90, best matches first.
    """
    matches = process.extract(
        text, entity_list, scorer=fuzz.WRatio, limit=20
    )
    return [entity for entity, score, *_ in matches if score >= 90]
def get_options(dataset: str) -> Tuple[str, Dict[str, Dict[str, str]]]:
    """Returns the possible options for a given dataset.

    Args:
        dataset: The dataset to get options for (matched by substring,
            e.g. "redial" or "opendialkg").

    Raises:
        ValueError: If the dataset is not supported.

    Returns:
        A tuple containing the prompt and a dictionary mapping option
        letters to ``{"attribute": ..., "template": ...}`` dicts.
    """
    # FIX: the return annotation previously claimed Dict[str, str], but the
    # option values are nested dicts. Runtime behavior is unchanged.
    # Attribute-elicitation options shared verbatim by both datasets; built
    # fresh on every call so callers may mutate the result safely.
    common = {
        "A": {"attribute": "genre", "template": "What genre do you like?"},
        "B": {"attribute": "actor", "template": "Which star do you like?"},
        "C": {
            "attribute": "director",
            "template": "Which director do you like?",
        },
    }
    if "redial" in dataset:
        instructions = (
            "To recommend me items that I will accept, you can choose one of "
            "the following options.\nA: ask my preference for genre\nB: ask my "
            "preference for actor\nC: ask my preference for director\nD: I can "
            "directly give recommendations\nPlease enter the option character. "
            "Please only response a character."
        )
        options = {
            **common,
            "D": {"attribute": "recommend", "template": ""},
        }
        return instructions, options
    if "opendialkg" in dataset:
        instructions = (
            "To recommend me items that I will accept, you can choose one of "
            "the following options.\nA: ask my preference for genre\nB: ask my "
            "preference for actor\nC: ask my preference for director\nD: ask "
            "my preference for writer\nE: I can directly give recommendations"
            "\nPlease enter the option character. Please only response a "
            "character."
        )
        options = {
            **common,
            "D": {
                "attribute": "writer",
                "template": "Which writer do you like?",
            },
            "E": {"attribute": "recommend", "template": ""},
        }
        return instructions, options
    raise ValueError(f"Dataset {dataset} is not supported.")
|