nehalelkaref committed
Commit 22bbb24 · 1 Parent(s): 40d48bd

Delete network.py

Files changed (1)
  1. network.py +0 -333
network.py DELETED
@@ -1,333 +0,0 @@
- import torch
- import torch.nn as nn
- from torch.nn.utils.rnn import pad_sequence
- from torch.nn.functional import cross_entropy, binary_cross_entropy
-
- from utils import Config, extract_spans, generate_targets
- from representation import TransformerRepresentation
- from layers import SpanEnumerationLayer
-
- DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-
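- # NOTE (editor's): Config, extract_spans, generate_targets,
- # TransformerRepresentation and SpanEnumerationLayer come from this repo's own
- # utils/representation/layers modules rather than an installed package.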
- class SpanNet(nn.Module):
-     def __init__(self, **kwargs):
-         super(SpanNet, self).__init__()
-         self.config = Config()
-         self.config.pos = kwargs.get('pos', None)
-         self.config.dp = kwargs.get('dp', 0.3)
-         self.config.transformer_model_name = kwargs.get('transformer_model_name', 'bert-base-uncased')
-         self.config.token_pooling = kwargs.get('token_pooling', 'sum')
-         self.device = kwargs.get('device', DEFAULT_DEVICE)
-
-         self.config.repr_type = kwargs.get('repr_type', 'token_classification')
-         assert self.config.repr_type in ['token_classification',
-                                          'span_enumeration'], 'Invalid representation type'
-
-         self.transformer = TransformerRepresentation(
-             model_name=self.config.transformer_model_name,
-             device=self.device).to(self.device)
-
-         self.transformer_dim = self.transformer.embedding_dim
-         if self.config.pos:
-             self.transformer.add_special_tokens([f'[{p}]' for p in self.config.pos])
-
-         self.span_tags = ['B', 'I', 'O']
-
-         self.enumeration_layer = SpanEnumerationLayer()
-         output_size = {'token_classification': len(self.span_tags),
-                        'span_enumeration': 1}
-         self.span_output_layer = nn.Sequential(
-             nn.Linear(self.transformer_dim, self.transformer_dim),
-             nn.ReLU(), nn.Dropout(p=self.config.dp),
-             nn.Linear(self.transformer_dim, output_size[self.config.repr_type]))
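-
-         # Example construction (editor's sketch; the kwargs mirror the config
-         # fields read above):
-         #     net = SpanNet(repr_type='token_classification', dp=0.3)
-         #     out = net([['Alice', 'visited', 'Paris']], tags=[['B', 'O', 'B']])
-         #     out['pred_tags'], out['loss']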
-
-     def to_dict(self):
-         return {
-             'model_config': self.config.__dict__,
-             'model_state_dict': self.state_dict()
-         }
-
-     @classmethod
-     def load_model(cls, model_path, device=DEFAULT_DEVICE):
-         res = torch.load(model_path, map_location=device)
-         model = cls(**res['model_config'])
-         model.load_state_dict(res['model_state_dict'], strict=False)
-         model.eval()
-         return model
-
-     @classmethod
-     def preds_to_sequences(cls, predictions, enumerations, length):
-         # assumes the function is applied per tensor/sample;
-         # sort candidate spans by predicted score, descending
-         enum_preds = {predictions[idx].item(): enumerations[idx]
-                       for idx in range(len(enumerations))}
-         sorted_enum_preds = dict(sorted(enum_preds.items(),
-                                         key=lambda item: item[0], reverse=True))
-
-         # greedily drop lower-scoring spans that clash with a kept span
-         spans_copy = list(sorted_enum_preds.values())
-         i = 0
-         while i < len(spans_copy):
-             filtered_spans = []
-             s, e = spans_copy[i]
-             for j in range(i + 1, len(spans_copy)):
-                 sj, ej = spans_copy[j]
-                 # span j clashes when it crosses span i's boundary on either side
-                 if (sj < s <= ej <= e) or (s < sj <= e < ej):
-                     filtered_spans.append(spans_copy[j])
-             i += 1
-             spans_copy = [span for span in spans_copy if span not in filtered_spans]
-
-         # assign BIO tags to the surviving spans
-         tagged_seq = ['O'] * length
-         for s, e in spans_copy:
-             tagged_seq[s] = 'B'
-             if e > s:
-                 tagged_seq[s + 1:e + 1] = ['I'] * (e - s)
-
-         return tagged_seq
-
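-     # Worked example for preds_to_sequences (editor's illustration): with
-     # predictions {0.9: (0, 2), 0.4: (1, 3)}, the higher-scoring (0, 2) is kept,
-     # the crossing (1, 3) is dropped, and for length 4 the method returns
-     # ['B', 'I', 'I', 'O'].
-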
-     def save_model(self, output_path):
-         torch.save(self.to_dict(), output_path)
-
-     def _extract_sentence_vectors(self, sentences, pos=None):
-         if pos and self.config.pos:
-             sentences = [[f'[{p}] {w}' for w, p in zip(sent, sent_pos)]
-                          for sent, sent_pos in zip(sentences, pos)]
-         outs = self.transformer(sentences, is_pretokenized=True,
-                                 token_pooling=self.config.token_pooling)
-         return outs.pooled_tokens
-
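-     # E.g. (editor's illustration): ['Alice', 'runs'] with pos ['NNP', 'VBZ']
-     # becomes ['[NNP] Alice', '[VBZ] runs']; the bracketed POS markers were
-     # registered as special tokens in __init__.
-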
-     def forward(self, sentences, pos=None, tags=None, **kwargs):
-         out_dict = {}
-         embs = self._extract_sentence_vectors(sentences, pos)
-         if kwargs.get('output_word_vecs', False):
-             out_dict['word_vecs'] = embs
-
-         lens = [len(s) for s in embs]
-
-         if self.config.repr_type == 'span_enumeration':
-             embs, enumerations = self.enumeration_layer(embs, lens)
-             lens = [len(e) for e in enumerations]
-
-         input_layer = pad_sequence(embs, batch_first=True)
-
-         span_scores = [torch.unbind(f)[:l]
-                        for f, l in zip(self.span_output_layer(input_layer), lens)]
-
-         if kwargs.get('output_span_scores', False):
-             out_dict['span_scores'] = span_scores
-         if self.config.repr_type == 'token_classification':
-             pred_span_ids = [[torch.argmax(s) for s in sc] for sc in span_scores]
-             pred_span_tags = [[self.span_tags[idx] for idx in sequence]
-                               for sequence in pred_span_ids]
-             out_dict['pred_tags'] = pred_span_tags
-         else:
-             # flatten the per-span logits so they can be sliced per sentence
-             flat_scores = torch.cat([torch.stack(sc)
-                                      for sc in span_scores]).squeeze(-1)
-             lens = [len(s) for s in sentences]
-             tagged_seq = []
-             prev_enum = 0
-             for enum, length in zip(enumerations, lens):
-                 scores = flat_scores[prev_enum:prev_enum + len(enum)]
-                 prev_enum += len(enum)
-                 tagged_seq.append(self.preds_to_sequences(scores, enum, length))
-             out_dict['pred_tags'] = tagged_seq
-
-         if tags is None:
-             return out_dict
-
-         if self.config.repr_type == 'span_enumeration':
-             targets = generate_targets(enumerations, tags)
-             targets = torch.tensor([t for st in targets for t in st],
-                                    dtype=torch.float).to(self.device)
-             span_loss = binary_cross_entropy(flat_scores.sigmoid(), targets)
-         else:
-             # limit the targets of each sentence to the words not truncated
-             # during tokenization
-             targets = torch.cat(
-                 [torch.tensor([self.span_tags.index(t[0]) for t, _ in zip(tg, sc)])
-                  for tg, sc in zip(tags, span_scores)]).to(self.device)
-             flat_scores = torch.stack([s for tg, sc in zip(tags, span_scores)
-                                        for _, s in zip(tg, sc)])
-             span_loss = cross_entropy(flat_scores, targets)
-         out_dict['loss'] = span_loss
-         return out_dict
-
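-     # Shape note (editor's): under token_classification each of a sentence's
-     # pooled token vectors receives a score over ['B', 'I', 'O']; under
-     # span_enumeration every enumerated (start, end) candidate receives a single
-     # logit that is sigmoided into a span probability for the BCE loss.
-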
-     def from_span_scores(self, span_scores):
-         pred_span_ids = [[torch.argmax(s) for s in sc] for sc in span_scores]
-         return [[self.span_tags[idx] for idx in sequence]
-                 for sequence in pred_span_ids]
-
-
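- # EntNet consumes the span proposals produced by SpanNet and classifies each
- # proposed span into an entity type; with tune_span_net=True the span loss is
- # added to the entity loss so the two nets are trained jointly.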
- class EntNet(nn.Module):
-     def __init__(self, **kwargs):
-         super(EntNet, self).__init__()
-         self.config = Config()
-         self.span_net = kwargs.get('span_net')
-         self.config.tune_span_net = kwargs.get('tune_span_net', False)
-         self.config.use_span_emb = kwargs.get('use_span_emb', False)
-         self.config.use_ent_markers = kwargs.get('use_ent_markers', False)
-         # it is possible to tune span_net without using its embeddings
-         if self.span_net and not self.config.tune_span_net:
-             for p in self.span_net.parameters():
-                 p.requires_grad = False
-         self.config.ent_tags = self.ent_tags = kwargs.get('ent_tags')
-         self.config.pos = kwargs.get('pos', None)
-         self.config.dp = kwargs.get('dp', 0.3)
-         self.config.transformer_model_name = kwargs.get('transformer_model_name', 'bert-base-uncased')
-         self.config.token_pooling = kwargs.get('token_pooling', 'first')
-         self.device = kwargs.get('device', DEFAULT_DEVICE)
-
-         self.transformer = TransformerRepresentation(
-             model_name=self.config.transformer_model_name,
-             device=self.device).to(self.device)
-         self.transformer_dim = self.transformer.embedding_dim
-
-         self.transformer.add_special_tokens(['[ENT]', '[/ENT]'])
-         self.transformer.add_special_tokens(['[INFO]', '[/INFO]'])
-         if self.config.pos:
-             self.transformer.add_special_tokens(
-                 [f'[{p}]' for p in self.config.pos])
-
-         # the entity head scores the concatenation of two pooled vectors,
-         # hence the doubled input width
-         self.ent_output_layer = nn.Sequential(
-             nn.Linear(2 * self.transformer_dim, 2 * self.transformer_dim),
-             nn.ReLU(), nn.Dropout(p=self.config.dp),
-             nn.Linear(2 * self.transformer_dim, len(self.config.ent_tags)))
-
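-     # Editor's note: the [INFO]/[/INFO] markers are registered above but never
-     # used elsewhere in this file; [ENT]/[/ENT] wrap predicted span boundaries
-     # in _extract_sentence_vectors below.
-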
217
- def to_dict(self):
218
- return {
219
- 'model_config': self.config.__dict__,
220
- 'span_net_config': self.span_net.config.__dict__ if self.span_net is not None else None,
221
- 'model_state_dict': self.state_dict()
222
- }
223
-
224
- @classmethod
225
- def load_model(cls, model_path, device=DEFAULT_DEVICE):
226
- res = torch.load(model_path, device)
227
- span_net = SpanNet(**res['span_net_config']) if res['span_net_config'] is not None else None
228
- model = cls(span_net=span_net, **res['model_config'])
229
- model.load_state_dict(res['model_state_dict'])
230
- model.eval()
231
- return model
232
-
233
- def save_model(self, output_path):
234
- torch.save(self.to_dict(), output_path)
235
-
-     def _extract_sentence_vectors(self, sentences, pos=None, ent_bounds=None):
-         if pos and self.config.pos:
-             sentences = [[f'[{p}] {w}' for w, p in zip(sent, sent_pos)]
-                          for sent, sent_pos in zip(sentences, pos)]
-         if ent_bounds and self.config.use_ent_markers:
-             for sent, sent_ents in zip(sentences, ent_bounds):
-                 for ent in sent_ents:
-                     sent[ent[0]] = f'[ENT] {sent[ent[0]]}'
-                     sent[ent[1]] = f'{sent[ent[1]]} [/ENT]'
-
-         outs = self.transformer(sentences, is_pretokenized=True,
-                                 token_pooling=self.config.token_pooling)
-         return outs.pooled_tokens
-
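-     # E.g. (editor's illustration): with bounds (1, 2) and use_ent_markers=True,
-     # ['met', 'Barack', 'Obama'] becomes ['met', '[ENT] Barack', 'Obama [/ENT]']
-     # before tokenization.
-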
-     def forward(self, sentences, pos=None, tags=None, **kwargs):
-         out_dict = {}
-         pred_span_seqs = kwargs.get('pred_tags', None)
-         if pred_span_seqs is None:
-             span_out = self.span_net(sentences, pos=pos,
-                                      output_word_vecs=self.config.use_span_emb,
-                                      tags=tags if self.config.tune_span_net else None)
-             pred_span_seqs = span_out['pred_tags']
-         bounds = [[e[1] for e in extract_spans(t, tagless=True)[3]]
-                   for t in pred_span_seqs]
-         if tags is not None:
-             gold_spans = [[e for e in extract_spans(t, tagless=True)[3]]
-                           for t in tags]
-             # a predicted span inherits a gold label only on an exact boundary match
-             matches = [[[g[0]
-                          for g in golds if p[0] == g[1][0] and p[1] == g[1][1]]
-                         for p in preds]
-                        for preds, golds in zip(bounds, gold_spans)]
-             targets = [[span_matches[0] if len(span_matches) == 1 else 'O'
-                         for span_matches in sent_matches]
-                        for sent_matches in matches]
-
-         # append every predicted span after a [SEP] so each span also receives
-         # a segment-level representation
-         sentences = [sent + [t for bd in sent_bounds
-                              for t in [self.transformer.tokenizer.sep_token] + sent[bd[0]:bd[1] + 1]]
-                      + [self.transformer.tokenizer.sep_token]
-                      for sent, sent_bounds in zip(sentences, bounds)]
-         sep_ids = [[i for i, s in enumerate(sent) if s == self.transformer.tokenizer.sep_token]
-                    for sent in sentences]
-         embs = self._extract_sentence_vectors(sentences, pos, bounds)
-         if kwargs.get('output_word_vecs', False):
-             out_dict['word_vecs'] = embs
-
-         # each span vector concatenates its in-sentence sum and its
-         # appended-segment sum
-         span_vecs = [
-             torch.stack([torch.cat((torch.sum(e[b[0]:b[1] + 1], dim=0),
-                                     torch.sum(e[spi[i]:spi[i + 1] + 1], dim=0)))
-                          for i, b in enumerate(bd)])
-             if bd else torch.zeros((0, 2 * self.transformer_dim)).to(self.device)
-             for e, bd, spi in zip(embs, bounds, sep_ids)]
-         ent_scores = [self.ent_output_layer(sv) for sv in span_vecs]
-         if kwargs.get('output_ent_scores', False):
-             out_dict['ent_scores'] = ent_scores
-             out_dict['bounds'] = bounds
-         if tags is None:
-             # note: the reconstructed sequences use the extended
-             # (span-augmented) sentence lengths, as in from_ent_scores
-             out_dict['pred_tags'] = self.from_ent_scores(ent_scores, sentences, bounds)
-             return out_dict
-
-         ent_targs = torch.tensor([self.ent_tags.index(t)
-                                   for targ in targets for t in targ],
-                                  dtype=torch.long).to(self.device)
-         ent_preds = torch.cat(ent_scores)
-         if not len(ent_preds):
-             out_dict['loss'] = None
-             return out_dict
-         ent_loss = cross_entropy(ent_preds, ent_targs)
-         out_dict['loss'] = ent_loss
-         if self.config.tune_span_net:
-             out_dict['loss'] += span_out['loss']
-         return out_dict
-
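-     # Editor's sketch of the intended joint pipeline (hypothetical paths/tags):
-     #     span_net = SpanNet.load_model('span_net.pt')
-     #     ent_net = EntNet(span_net=span_net, ent_tags=['O', 'PER', 'LOC'])
-     #     out = ent_net([['Alice', 'visited', 'Paris']],
-     #                   tags=[['B-PER', 'O', 'B-LOC']])
-     #     out['loss'].backward()
-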
-     def from_ent_scores(self, ent_scores, sentences, bounds):
-         max_tags = [[self.ent_tags[torch.argmax(e)] for e in es]
-                     for es in ent_scores]
-         # reconstruct full BIO sequences from the per-span predictions
-         sent_lens = [len(s) for s in sentences]
-         combined_sequences = []
-         for mt, bnd, slen in zip(max_tags, bounds, sent_lens):
-             x = ['O' for _ in range(slen)]
-             for t, b in zip(mt, bnd):
-                 x[b[0]] = 'O' if t == 'O' else f'B-{t}'
-                 for i in range(b[0] + 1, b[1] + 1):
-                     x[i] = 'O' if t == 'O' else f'I-{t}'
-             combined_sequences.append(x)
-         return combined_sequences
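-
-     # Worked example (editor's illustration): a single bound (0, 1) predicted
-     # as 'PER' over a 3-token sentence yields ['B-PER', 'I-PER', 'O'].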