Nol00 committed on
Commit
00f8114
·
verified ·
1 Parent(s): 1fd9723

Update src/model/utils.py

Browse files
Files changed (1) hide show
  1. src/model/utils.py +202 -201
src/model/utils.py CHANGED
@@ -1,201 +1,202 @@
1
- import json
2
- import random
3
- from typing import Dict, List, Optional, Tuple, Union
4
-
5
- import torch
6
- from rapidfuzz import fuzz, process
7
- from torch import nn
8
- from torch.nn import functional as F
9
-
10
- special_tokens_dict = {"pad_token": "<|pad|>"}
11
-
12
-
13
def load_jsonl_data(file):
    """Load a JSON Lines file into a list of decoded records.

    Args:
        file: Path of the UTF-8 encoded file to read; each non-blank line
            must hold one JSON document.

    Returns:
        A list with one decoded JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(file, encoding="utf-8") as f:
        # Blank or whitespace-only lines carry no record; skip them instead of
        # letting json.loads fail on an empty document.
        return [json.loads(line) for line in f if line.strip()]
20
-
21
-
22
def simple_collate(batch):
    """Return ``batch`` unchanged (identity collate).

    No stacking, padding or copying is performed; the very same object that
    was passed in is handed back.
    """
    return batch
24
-
25
-
26
def sample_data(data_list, shot=1, debug=False, number_for_debug=320):
    """Select a subset of ``data_list``.

    Args:
        data_list: Sequence of examples to select from.
        shot: Selection rule. A value below 1 is treated as a fraction and
            that share of the examples is drawn at random without
            replacement. A value above 1 is treated as a count and the first
            ``int(shot)`` examples are kept. Exactly 1 keeps everything.
        debug: If true, only the first ``number_for_debug`` examples are
            considered before ``shot`` is applied.
        number_for_debug: Size of the debug prefix.

    Returns:
        The selected examples. When ``shot`` is exactly 1 this is the input
        (or its debug prefix) itself rather than a copy.

    Raises:
        ValueError: If ``shot`` is negative (propagated from
            ``random.sample``).
    """
    if debug:
        data_list = data_list[:number_for_debug]

    if shot < 1:
        sample_size = int(len(data_list) * shot)
        picked = random.sample(range(len(data_list)), sample_size)
        return [data_list[idx] for idx in picked]

    if shot > 1:
        # Slicing clamps the count to the available data, so asking for more
        # examples than exist (e.g. after the debug truncation) returns them
        # all instead of raising IndexError.
        return list(data_list[: int(shot)])

    return data_list
40
-
41
-
42
def padded_tensor(
    items: List[Union[List[int], torch.LongTensor]],
    pad_id: int = 0,
    pad_tail: bool = True,
    device: torch.device = torch.device("cpu"),
    debug: bool = False,
    max_length: Optional[int] = None,
) -> torch.Tensor:
    """Pack variable-length id sequences into one padded 2-D long tensor.

    Args:
        items: Sequences to pack; each is a list of ints or a 1-D tensor.
        pad_id: Value written into the unused positions.
        pad_tail: If true, sequences are left-aligned and padding goes at the
            end of each row; otherwise sequences are right-aligned.
        device: Device on which the output tensor is allocated.
        debug: When true (and ``max_length`` is given), the width is raised to
            at least ``max_length``.
        max_length: Minimum width used in debug mode; ignored otherwise.

    Returns:
        A ``(len(items), width)`` tensor of dtype ``torch.long`` where
        ``width`` is the longest sequence length, and never less than 1. An
        empty ``items`` list yields a ``(0, 1)`` tensor.
    """
    lens: List[int] = [len(item) for item in items]
    # default=0 keeps an empty batch from crashing max(); the outer max keeps
    # the time dimension at least 1 wide even when every item is empty.
    width = max(max(lens, default=0), 1)
    if debug and max_length is not None:
        width = max(width, max_length)

    output = torch.full(
        (len(items), width), fill_value=pad_id, dtype=torch.long, device=device
    )

    for row, (item, length) in enumerate(zip(items, lens)):
        if length == 0:
            # Nothing to copy; the row stays all padding.
            continue
        if not isinstance(item, torch.Tensor):
            item = torch.as_tensor(item, dtype=torch.long, device=device)
        if pad_tail:
            output[row, :length] = item
        else:
            output[row, width - length:] = item

    return output
74
-
75
-
76
class SelfAttention(nn.Module):
    """Attention pooling that collapses a sequence into a single vector.

    A two-layer scorer assigns one score to every position; the scores are
    normalised over the sequence and used to average the inputs.

    NOTE(review): earlier revisions applied the softmax over the trailing
    singleton dimension, which made every weight 1 (a plain sum that ignored
    the mask). Checkpoints trained with that behaviour will produce different
    outputs with this corrected normalisation.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # Scorer: hidden -> hidden -> 1, i.e. one unnormalised score per token.
        self.attn = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, x, mask=None):
        """

        Args:
            x (bs, seq_len, hs)
            mask (bs, seq_len): False for masked token.

        Returns:
            (bs, hs)
        """
        scores = self.attn(x)  # (bs, seq_len, 1)
        if mask is not None:
            # Push masked positions far down so they get ~0 weight. Built
            # out-of-place and in the scores' dtype; .bool() guards against
            # integer masks, on which `~` would be a bitwise NOT.
            penalty = (~mask.bool()).unsqueeze(-1).to(scores.dtype) * -1e4
            scores = scores + penalty
        # Normalise across the sequence axis; the last axis has size 1, so a
        # softmax over it would turn every weight into 1.
        weights = F.softmax(scores, dim=1)
        pooled = weights.transpose(1, 2) @ x  # (bs, 1, hs)
        return pooled.squeeze(1)
102
-
103
-
104
def shift_tokens_right(
    input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int
):
    """
    Shift input ids one token to the right.

    The first column of the result is ``decoder_start_token_id``, the last
    column of ``input_ids`` is dropped, and any ``-100`` left in the result is
    replaced by ``pad_token_id``. ``input_ids`` itself is not modified.

    Raises:
        ValueError: If ``pad_token_id`` is None.
    """
    shifted = torch.zeros_like(input_ids)
    shifted[:, 0] = decoder_start_token_id
    # Slice assignment copies the values, so the source tensor stays untouched.
    shifted[:, 1:] = input_ids[:, :-1].detach()

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")

    # replace possible -100 values in labels by `pad_token_id`
    ignore_positions = shifted == -100
    return shifted.masked_fill_(ignore_positions, pad_token_id)
120
-
121
-
122
- # dbpedia get entity
123
- # def get_entity(text, SPOTLIGHT_CONFIDENCE):
124
- # DBPEDIA_SPOTLIGHT_ADDR = " http://0.0.0.0:2222/rest/annotate"
125
- # headers = {"accept": "application/json"}
126
- # params = {"text": text, "confidence": SPOTLIGHT_CONFIDENCE}
127
-
128
- # response = requests.get(DBPEDIA_SPOTLIGHT_ADDR, headers=headers, params=params)
129
- # response = response.json()
130
- # return (
131
- # [f"<{x['@URI']}>" for x in response["Resources"]]
132
- # if "Resources" in response
133
- # else []
134
- # )
135
-
136
-
137
- # rapidfuzz get entity
138
def get_entity(text, entity_list):
    """Return the entries of ``entity_list`` that fuzzily match ``text``.

    The 20 best candidates under rapidfuzz's ``WRatio`` scorer are retrieved
    and only those scoring at least 90 are kept, in the order rapidfuzz
    returned them.
    """
    candidates = process.extract(
        text, entity_list, scorer=fuzz.WRatio, limit=20
    )
    # Each candidate starts with (choice, score); trailing fields are ignored.
    return [choice for choice, score, *_ in candidates if score >= 90]
146
-
147
-
148
def get_options(dataset: str) -> Tuple[str, Dict[str, Dict[str, str]]]:
    """Returns the possible options for a given dataset.

    The prompt menu and the option dictionary are generated from one attribute
    list, so every option character offered in the prompt has a matching
    dictionary entry (including the final "recommend" choice).

    Args:
        dataset: The dataset to get options for. Matched by substring; a name
            containing "redial" is handled before one containing
            "opendialkg".

    Raises:
        ValueError: If the dataset is not supported.

    Returns:
        A tuple containing the prompt and a dictionary mapping each option
        character to its "attribute" name and question "template".
    """
    if "redial" in dataset:
        attributes = ["genre", "actor", "director"]
    elif "opendialkg" in dataset:
        attributes = ["genre", "actor", "director", "writer"]
    else:
        raise ValueError(f"Dataset {dataset} is not supported.")

    templates = {
        "genre": "What genre do you like?",
        "actor": "Which star do you like?",
        "director": "Which director do you like?",
        "writer": "Which writer do you like?",
    }
    letters = "ABCDE"

    menu = []
    options = {}
    for letter, attribute in zip(letters, attributes):
        menu.append(f"{letter}: ask my preference for {attribute}")
        options[letter] = {
            "attribute": attribute,
            "template": templates[attribute],
        }

    # The letter right after the last attribute question is the
    # "recommend now" choice; it has no follow-up question template.
    recommend_letter = letters[len(attributes)]
    menu.append(f"{recommend_letter}: I can directly give recommendations")
    options[recommend_letter] = {"attribute": "recommend", "template": ""}

    instructions = "\n".join(
        [
            "To recommend me items that I will accept, you can choose one of "
            "the following options.",
            *menu,
            "Please enter the option character. Please only response a "
            "character.",
        ]
    )
    return instructions, options
 
 
1
+ import json
2
+ import random
3
+ from typing import Dict, List, Optional, Tuple, Union
4
+
5
+ import torch
6
+ from rapidfuzz import fuzz, process
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+
10
+ special_tokens_dict = {"pad_token": "<|pad|>"}
11
+
12
+
13
def load_jsonl_data(file):
    """Load a JSON Lines file into a list of decoded records.

    Args:
        file: Path of the UTF-8 encoded file to read; each non-blank line
            must hold one JSON document.

    Returns:
        A list with one decoded JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(file, encoding="utf-8") as f:
        # Blank or whitespace-only lines carry no record; skip them instead of
        # letting json.loads fail on an empty document.
        return [json.loads(line) for line in f if line.strip()]
20
+
21
+
22
def simple_collate(batch):
    """Return ``batch`` unchanged (identity collate).

    No stacking, padding or copying is performed; the very same object that
    was passed in is handed back.
    """
    return batch
24
+
25
+
26
def sample_data(data_list, shot=1, debug=False, number_for_debug=320):
    """Select a subset of ``data_list``.

    Args:
        data_list: Sequence of examples to select from.
        shot: Selection rule. A value below 1 is treated as a fraction and
            that share of the examples is drawn at random without
            replacement. A value above 1 is treated as a count and the first
            ``int(shot)`` examples are kept. Exactly 1 keeps everything.
        debug: If true, only the first ``number_for_debug`` examples are
            considered before ``shot`` is applied.
        number_for_debug: Size of the debug prefix.

    Returns:
        The selected examples. When ``shot`` is exactly 1 this is the input
        (or its debug prefix) itself rather than a copy.

    Raises:
        ValueError: If ``shot`` is negative (propagated from
            ``random.sample``).
    """
    if debug:
        data_list = data_list[:number_for_debug]

    if shot < 1:
        sample_size = int(len(data_list) * shot)
        picked = random.sample(range(len(data_list)), sample_size)
        return [data_list[idx] for idx in picked]

    if shot > 1:
        # Slicing clamps the count to the available data, so asking for more
        # examples than exist (e.g. after the debug truncation) returns them
        # all instead of raising IndexError.
        return list(data_list[: int(shot)])

    return data_list
40
+
41
+
42
def padded_tensor(
    items: List[Union[List[int], torch.LongTensor]],
    pad_id: int = 0,
    pad_tail: bool = True,
    device: torch.device = torch.device("cpu"),
    debug: bool = False,
    max_length: Optional[int] = None,
) -> torch.Tensor:
    """Pack variable-length id sequences into one padded 2-D long tensor.

    Args:
        items: Sequences to pack; each is a list of ints or a 1-D tensor.
        pad_id: Value written into the unused positions.
        pad_tail: If true, sequences are left-aligned and padding goes at the
            end of each row; otherwise sequences are right-aligned.
        device: Device on which the output tensor is allocated.
        debug: When true (and ``max_length`` is given), the width is raised to
            at least ``max_length``.
        max_length: Minimum width used in debug mode; ignored otherwise.

    Returns:
        A ``(len(items), width)`` tensor of dtype ``torch.long`` where
        ``width`` is the longest sequence length, and never less than 1. An
        empty ``items`` list yields a ``(0, 1)`` tensor.
    """
    lens: List[int] = [len(item) for item in items]
    # default=0 keeps an empty batch from crashing max(); the outer max keeps
    # the time dimension at least 1 wide even when every item is empty.
    width = max(max(lens, default=0), 1)
    if debug and max_length is not None:
        width = max(width, max_length)

    output = torch.full(
        (len(items), width), fill_value=pad_id, dtype=torch.long, device=device
    )

    for row, (item, length) in enumerate(zip(items, lens)):
        if length == 0:
            # Nothing to copy; the row stays all padding.
            continue
        if not isinstance(item, torch.Tensor):
            item = torch.as_tensor(item, dtype=torch.long, device=device)
        if pad_tail:
            output[row, :length] = item
        else:
            output[row, width - length:] = item

    return output
74
+
75
+
76
class SelfAttention(nn.Module):
    """Attention pooling that collapses a sequence into a single vector.

    A two-layer scorer assigns one score to every position; the scores are
    normalised over the sequence and used to average the inputs.

    NOTE(review): earlier revisions applied the softmax over the trailing
    singleton dimension, which made every weight 1 (a plain sum that ignored
    the mask). Checkpoints trained with that behaviour will produce different
    outputs with this corrected normalisation.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # Scorer: hidden -> hidden -> 1, i.e. one unnormalised score per token.
        self.attn = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, x, mask=None):
        """

        Args:
            x (bs, seq_len, hs)
            mask (bs, seq_len): False for masked token.

        Returns:
            (bs, hs)
        """
        scores = self.attn(x)  # (bs, seq_len, 1)
        if mask is not None:
            # Push masked positions far down so they get ~0 weight. Built
            # out-of-place and in the scores' dtype; .bool() guards against
            # integer masks, on which `~` would be a bitwise NOT.
            penalty = (~mask.bool()).unsqueeze(-1).to(scores.dtype) * -1e4
            scores = scores + penalty
        # Normalise across the sequence axis; the last axis has size 1, so a
        # softmax over it would turn every weight into 1.
        weights = F.softmax(scores, dim=1)
        pooled = weights.transpose(1, 2) @ x  # (bs, 1, hs)
        return pooled.squeeze(1)
102
+
103
+
104
def shift_tokens_right(
    input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int
):
    """
    Shift input ids one token to the right.

    The first column of the result is ``decoder_start_token_id``, the last
    column of ``input_ids`` is dropped, and any ``-100`` left in the result is
    replaced by ``pad_token_id``. ``input_ids`` itself is not modified.

    Raises:
        ValueError: If ``pad_token_id`` is None.
    """
    shifted = torch.zeros_like(input_ids)
    shifted[:, 0] = decoder_start_token_id
    # Slice assignment copies the values, so the source tensor stays untouched.
    shifted[:, 1:] = input_ids[:, :-1].detach()

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")

    # replace possible -100 values in labels by `pad_token_id`
    ignore_positions = shifted == -100
    return shifted.masked_fill_(ignore_positions, pad_token_id)
120
+
121
+
122
+ # dbpedia get entity
123
+ # def get_entity(text, SPOTLIGHT_CONFIDENCE):
124
+ # DBPEDIA_SPOTLIGHT_ADDR = " http://0.0.0.0:2222/rest/annotate"
125
+ # headers = {"accept": "application/json"}
126
+ # params = {"text": text, "confidence": SPOTLIGHT_CONFIDENCE}
127
+
128
+ # response = requests.get(DBPEDIA_SPOTLIGHT_ADDR, headers=headers, params=params)
129
+ # response = response.json()
130
+ # return (
131
+ # [f"<{x['@URI']}>" for x in response["Resources"]]
132
+ # if "Resources" in response
133
+ # else []
134
+ # )
135
+
136
+
137
+ # rapidfuzz get entity
138
def get_entity(text, entity_list):
    """Return the entries of ``entity_list`` that fuzzily match ``text``.

    The 20 best candidates under rapidfuzz's ``WRatio`` scorer are retrieved
    and only those scoring at least 90 are kept, in the order rapidfuzz
    returned them.
    """
    candidates = process.extract(
        text, entity_list, scorer=fuzz.WRatio, limit=20
    )
    # Each candidate starts with (choice, score); trailing fields are ignored.
    return [choice for choice, score, *_ in candidates if score >= 90]
146
+
147
+
148
def get_options(dataset: str) -> Tuple[str, Dict[str, Dict[str, str]]]:
    """Returns the possible options for a given dataset.

    The prompt menu and the option dictionary are generated from one attribute
    list, so every option character offered in the prompt has a matching
    dictionary entry (including the final "recommend" choice).

    Args:
        dataset: The dataset to get options for. Matched by substring; a name
            containing "redial" is handled before one containing
            "opendialkg".

    Raises:
        ValueError: If the dataset is not supported.

    Returns:
        A tuple containing the prompt and a dictionary mapping each option
        character to its "attribute" name and question "template".
    """
    if "redial" in dataset:
        attributes = ["genre", "actor", "director"]
    elif "opendialkg" in dataset:
        attributes = ["genre", "actor", "director", "writer"]
    else:
        raise ValueError(f"Dataset {dataset} is not supported.")

    templates = {
        "genre": "What genre do you like?",
        "actor": "Which star do you like?",
        "director": "Which director do you like?",
        "writer": "Which writer do you like?",
    }
    letters = "ABCDE"

    menu = []
    options = {}
    for letter, attribute in zip(letters, attributes):
        menu.append(f"{letter}: ask my preference for {attribute}")
        options[letter] = {
            "attribute": attribute,
            "template": templates[attribute],
        }

    # The letter right after the last attribute question is the
    # "recommend now" choice; it has no follow-up question template.
    recommend_letter = letters[len(attributes)]
    menu.append(f"{recommend_letter}: I can directly give recommendations")
    options[recommend_letter] = {"attribute": "recommend", "template": ""}

    instructions = "\n".join(
        [
            "To recommend me items that I will accept, you can choose one of "
            "the following options.",
            *menu,
            "Please enter the option character. Please only response a "
            "character.",
        ]
    )
    return instructions, options