nehalelkaref's picture
Update network.py
ed540bf
raw
history blame
14.6 kB
import copy
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.nn.functional import cross_entropy, binary_cross_entropy
from tqdm.auto import tqdm
from utils import Config, extract_spans, generate_targets
from representation import TransformerRepresentation
from layers import SpanEnumerationLayer
# Use the CUDA device when PyTorch reports one as available, otherwise the CPU.
DEFAULT_DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class SpanNet(nn.Module):
    """Span detector on top of a transformer encoder.

    Two representation types are supported (``repr_type``):

    * ``'token_classification'``: every word is scored against the span tags
      ``'B'``, ``'I'`` and ``'O'``;
    * ``'span_enumeration'``: ``SpanEnumerationLayer`` produces candidate
      spans (used as inclusive ``(start, end)`` word positions) with one
      vector each, every candidate gets a single logit, and overlapping
      candidates are resolved greedily by score (see ``preds_to_sequences``).
    """

    def __init__(self, **kwargs):
        super(SpanNet, self).__init__()
        self.config = Config()
        # optional tag inventory (presumably POS tags); each tag is added to the
        # tokenizer as the special token "[tag]" and prefixed to its word
        self.config.pos = kwargs.get('pos', None)
        self.config.dp = kwargs.get('dp', 0.3)  # dropout of the output head
        self.config.transformer_model_name = kwargs.get('transformer_model_name', 'bert-base-uncased')
        self.config.token_pooling = kwargs.get('token_pooling', 'sum')
        # the device is kept out of the config, so to_dict() does not save it
        self.device = kwargs.get('device', DEFAULT_DEVICE)
        self.config.repr_type = kwargs.get('repr_type', 'token_classification')
        assert self.config.repr_type in ['token_classification',
                                         'span_enumeration'], 'Invalid representaton type'
        self.transformer = TransformerRepresentation(
            model_name=self.config.transformer_model_name,
            device=self.device).to(self.device)
        self.transformer_dim = self.transformer.embedding_dim
        if self.config.pos:
            self.transformer.add_special_tokens([f'[{p}]' for p in self.config.pos])
        self.span_tags = ['B', 'I', 'O']
        self.enumeration_layer = SpanEnumerationLayer()
        # token classification scores every word against each span tag;
        # span enumeration produces a single logit per candidate span
        output_size = {'token_classification': len(self.span_tags),
                       'span_enumeration': 1}
        self.span_output_layer = nn.Sequential(
            nn.Linear(self.transformer_dim, self.transformer_dim),
            nn.ReLU(), nn.Dropout(p=self.config.dp),
            nn.Linear(self.transformer_dim, output_size[self.config.repr_type]))

    def to_dict(self):
        """Return the config and the state dict, the format read by ``load_model``."""
        return {
            'model_config': self.config.__dict__,
            'model_state_dict': self.state_dict()
        }

    @classmethod
    def load_model(cls, model_path, device=None):
        """Rebuild a model written by ``save_model`` and put it in eval mode.

        Args:
            model_path: path of the checkpoint.
            device: device to map the checkpoint to and to build the model on;
                defaults to ``DEFAULT_DEVICE``.
        """
        device = DEFAULT_DEVICE if device is None else device
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        res = torch.load(model_path, device)
        # the device is not part of the saved config; without passing it the
        # model would be built on DEFAULT_DEVICE whatever was requested
        model = cls(**dict(res['model_config'], device=device))
        model.load_state_dict(res['model_state_dict'], strict=False)
        model.eval()
        return model

    @staticmethod
    def preds_to_sequences(predictions, enumerations, length):
        """Decode the span scores of ONE sentence into a ``'B'/'I'/'O'`` list.

        Candidate spans are visited from the highest to the lowest score and a
        span is kept only if it does not overlap any span kept before it.
        Spans with equal scores keep their enumeration order.

        Args:
            predictions: one single-element score tensor per candidate span.
            enumerations: the candidate spans as inclusive ``(start, end)``
                word positions, aligned with ``predictions``.
            length: number of words of the sentence (length of the result).

        Returns:
            list of ``length`` tags: ``'B'`` on the first word of every kept
            span, ``'I'`` on its other words and ``'O'`` everywhere else.
        """
        # NOTE(review): no score threshold is applied, so every candidate that
        # survives the overlap check is tagged, however low its score.
        # A list of pairs (not a dict keyed by score) so that equal scores do
        # not overwrite each other.
        scored_spans = sorted(
            ((score.item(), span) for score, span in zip(predictions, enumerations)),
            key=lambda pair: pair[0], reverse=True)
        kept = []
        for _, (start, end) in scored_spans:
            # inclusive spans are disjoint iff one ends before the other starts
            if all(end < k_start or k_end < start for k_start, k_end in kept):
                kept.append((start, end))
        tagged_seq = ['O'] * length
        for start, end in kept:
            tagged_seq[start] = 'B'
            tagged_seq[start + 1:end + 1] = ['I'] * (end - start)
        return tagged_seq

    def save_model(self, output_path):
        """Write ``to_dict()`` to ``output_path`` with ``torch.save``."""
        torch.save(self.to_dict(), output_path)

    def _extract_sentence_vectors(self, sentences, pos=None):
        """Encode pre-tokenized sentences and return their pooled word vectors.

        When both ``pos`` and ``config.pos`` are set, every word is prefixed
        with its tag as ``'[tag] word'`` before encoding.
        """
        if pos and self.config.pos:
            sentences = [[f'[{tag}] {word}' for word, tag in zip(sent, sent_pos)]
                         for sent, sent_pos in zip(sentences, pos)]
        outs = self.transformer(sentences, is_pretokenized=True,
                                token_pooling=self.config.token_pooling)
        return outs.pooled_tokens

    def forward(self, sentences, pos=None, tags=None, **kwargs):
        """Predict span tags for a batch of pre-tokenized sentences.

        Args:
            sentences: list of sentences, each a list of words.
            pos: optional per-word tags, used only when the model was built
                with ``pos``.
            tags: optional gold per-word tags; when given, ``'loss'`` is added
                to the output. Token classification uses the first character
                of each tag (it must be one of ``span_tags``); span enumeration
                hands the tags to ``generate_targets``.
            output_word_vecs: (kwarg) also return the word vectors.
            output_span_scores: (kwarg) also return the raw per-position scores.

        Returns:
            dict with ``'pred_tags'`` (one ``'B'/'I'/'O'`` list per sentence),
            plus ``'word_vecs'``, ``'span_scores'`` and ``'loss'`` as requested.
        """
        out_dict = {}
        embs = self._extract_sentence_vectors(sentences, pos)
        if kwargs.get('output_word_vecs', False):
            out_dict['word_vecs'] = embs
        lens = [len(s) for s in embs]
        enumerations = None
        if self.config.repr_type == 'span_enumeration':
            # the word vectors are replaced by one vector per candidate span
            embs, enumerations = self.enumeration_layer(embs, lens)
            lens = [len(e) for e in enumerations]
        input_layer = pad_sequence(embs, batch_first=True)
        # keep only the scores of the real (non-padding) positions
        span_scores = [torch.unbind(f)[:l]
                       for f, l in zip(self.span_output_layer(input_layer), lens)]
        if kwargs.get('output_span_scores', False):
            out_dict['span_scores'] = span_scores
        if self.config.repr_type == 'token_classification':
            out_dict['pred_tags'] = self.from_span_scores(span_scores)
        else:
            # decode every sentence on its own scores; the tag sequence has one
            # entry per input word
            out_dict['pred_tags'] = [
                self.preds_to_sequences(scores, enum, len(sent))
                for scores, enum, sent in zip(span_scores, enumerations, sentences)]
        if tags is None:
            return out_dict
        if self.config.repr_type == 'span_enumeration':
            targets = generate_targets(enumerations, tags)
            targets = torch.Tensor([t for st in targets for t in st]).to(self.device)
            # concatenate the (1,)-shaped span scores; rebuilding them with
            # torch.Tensor would cut them off the autograd graph
            flat_scores = torch.cat([s for sc in span_scores for s in sc])
            span_loss = binary_cross_entropy(flat_scores.sigmoid(), targets)
        else:
            # limit the targets of each sentence to the words not truncated during tokenization
            targets = torch.cat(
                [torch.tensor([self.span_tags.index(t[0]) for t, _ in zip(tg, sc)],
                              dtype=torch.long)
                 for tg, sc in zip(tags, span_scores)]).to(self.device)
            flat_scores = torch.stack([s for tg, sc in zip(tags, span_scores)
                                       for _, s in zip(tg, sc)])
            span_loss = cross_entropy(flat_scores, targets)
        out_dict['loss'] = span_loss
        return out_dict

    def from_span_scores(self, span_scores):
        """Map token-classification scores to span tags by arg-max.

        Args:
            span_scores: per sentence, a sequence of per-word score tensors
                over ``span_tags``.

        Returns:
            one list of ``'B'/'I'/'O'`` tags per sentence.
        """
        pred_span_ids = [[torch.argmax(s) for s in sc] for sc in span_scores]
        return [[self.span_tags[idx] for idx in sequence]
                for sequence in pred_span_ids]
class EntNet(nn.Module):
    """Entity typer for the spans found by a ``SpanNet``.

    Every predicted span is appended to its sentence after a separator token.
    Its vector is the sum of its word vectors inside the sentence concatenated
    with the sum of the vectors of its appended copy (both surrounding
    separators included), and it is classified into one of ``ent_tags``.
    """

    def __init__(self, **kwargs):
        super(EntNet, self).__init__()
        self.config = Config()
        self.span_net = kwargs.get('span_net')
        self.config.tune_span_net = kwargs.get('tune_span_net', False)
        # NOTE(review): forwarded to span_net as output_word_vecs, but the
        # returned word vectors are not used anywhere in this class
        self.config.use_span_emb = kwargs.get('use_span_emb', False)
        self.config.use_ent_markers = kwargs.get('use_ent_markers', False)
        # it is possible to tune span_net without using its embeddings
        if self.span_net is not None and not self.config.tune_span_net:
            for p in self.span_net.parameters():
                p.requires_grad = False
        self.config.ent_tags = self.ent_tags = kwargs.get('ent_tags')
        self.config.pos = kwargs.get('pos', None)
        self.config.dp = kwargs.get('dp', 0.3)
        self.config.transformer_model_name = kwargs.get('transformer_model_name', 'bert-base-uncased')
        self.config.token_pooling = kwargs.get('token_pooling', 'first')
        # the device is kept out of the config, so to_dict() does not save it
        self.device = kwargs.get('device', DEFAULT_DEVICE)
        self.transformer = TransformerRepresentation(
            model_name=self.config.transformer_model_name,
            device=self.device).to(self.device)
        self.transformer_dim = self.transformer.embedding_dim
        self.transformer.add_special_tokens(['[ENT]', '[/ENT]'])
        self.transformer.add_special_tokens(['[INFO]', '[/INFO]'])
        if self.config.pos:
            self.transformer.add_special_tokens(
                ['[' + p + ']' for p in self.config.pos])
        # the classifier sees two concatenated vectors per span (see forward)
        self.ent_output_layer = nn.Sequential(
            nn.Linear(2 * self.transformer_dim, 2 * self.transformer_dim),
            nn.ReLU(), nn.Dropout(p=self.config.dp),
            nn.Linear(2 * self.transformer_dim, len(self.config.ent_tags)))

    def to_dict(self):
        """Return both configs and the state dict, the format read by ``load_model``."""
        return {
            'model_config': self.config.__dict__,
            'span_net_config': self.span_net.config.__dict__ if self.span_net is not None else None,
            'model_state_dict': self.state_dict()
        }

    @classmethod
    def load_model(cls, model_path, device=None):
        """Rebuild a model (and its span net) written by ``save_model``; eval mode.

        Args:
            model_path: path of the checkpoint.
            device: device to map the checkpoint to and to build both networks
                on; defaults to ``DEFAULT_DEVICE``.
        """
        device = DEFAULT_DEVICE if device is None else device
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        res = torch.load(model_path, device)
        span_net_config = res['span_net_config']
        # the device is not part of the saved configs, so pass it explicitly
        span_net = (SpanNet(**dict(span_net_config, device=device))
                    if span_net_config is not None else None)
        model = cls(span_net=span_net, **dict(res['model_config'], device=device))
        model.load_state_dict(res['model_state_dict'])
        model.eval()
        return model

    def save_model(self, output_path):
        """Write ``to_dict()`` to ``output_path`` with ``torch.save``."""
        torch.save(self.to_dict(), output_path)

    def _extract_sentence_vectors(self, sentences, pos=None, ent_bounds=None):
        """Encode pre-tokenized sentences and return their pooled word vectors.

        Args:
            sentences: list of sentences, each a list of words (not modified).
            pos: optional per-word tags; when ``config.pos`` is set, the words
                that have a tag are prefixed as ``'[tag] word'``. Words past
                the end of the tag list (e.g. the separators and span copies
                appended by ``forward``) are kept untagged.
            ent_bounds: optional inclusive ``(start, end)`` word positions per
                sentence; with ``config.use_ent_markers`` each span is wrapped
                in ``[ENT]`` ... ``[/ENT]``.
        """
        if pos and self.config.pos:
            # zip() would silently drop the untagged trailing words and shift
            # the separator positions that forward() relies on
            sentences = [[f'[{sent_pos[i]}] {word}' if i < len(sent_pos) else word
                          for i, word in enumerate(sent)]
                         for sent, sent_pos in zip(sentences, pos)]
        if ent_bounds and self.config.use_ent_markers:
            # work on copies so the caller's word lists are left untouched
            sentences = [list(sent) for sent in sentences]
            for sent, sent_ents in zip(sentences, ent_bounds):
                for ent in sent_ents:
                    sent[ent[0]] = f'[ENT] {sent[ent[0]]}'
                    sent[ent[1]] = f'{sent[ent[1]]} [/ENT]'
        outs = self.transformer(sentences, is_pretokenized=True,
                                token_pooling=self.config.token_pooling)
        return outs.pooled_tokens

    def forward(self, sentences, pos=None, tags=None, **kwargs):
        """Type the spans of a batch of pre-tokenized sentences.

        Args:
            sentences: list of sentences, each a list of words.
            pos: optional per-word tags, passed to ``span_net`` and, when the
                model was built with ``pos``, used by this model's encoder.
            tags: optional gold per-word entity tags; when given the output
                holds ``'loss'`` instead of ``'pred_tags'``.
            pred_tags: (kwarg) precomputed span tag sequences; when given,
                ``span_net`` is not run.
            output_word_vecs / output_ent_scores: (kwargs) also return the
                encoder outputs / the per-span scores.

        Returns:
            dict with ``'bounds'`` and either ``'pred_tags'`` (one
            ``'B-<type>'/'I-<type>'/'O'`` list per input sentence, as long as
            that sentence) or ``'loss'`` (None when there is no span to type
            and ``span_net`` is not tuned).
        """
        out_dict = {}
        span_out = None
        pred_span_seqs = kwargs.get('pred_tags', None)
        if pred_span_seqs is None:
            span_out = self.span_net(sentences, pos=pos,
                                     output_word_vecs=self.config.use_span_emb,
                                     tags=tags if self.config.tune_span_net else None)
            pred_span_seqs = span_out['pred_tags']
        # inclusive (start, end) word positions of every predicted span
        bounds = [[e[1] for e in extract_spans(t, tagless=True)[3]]
                  for t in pred_span_seqs]
        if tags is not None:
            gold_spans = [[e for e in extract_spans(t, tagless=True)[3]]
                          for t in tags]
            # a predicted span gets the label g[0] of the single gold span with
            # exactly the same boundaries, and 'O' otherwise
            matches = [[[g[0]
                         for g in golds if p[0] == g[1][0] and p[1] == g[1][1]]
                        for p in preds]
                       for preds, golds in zip(bounds, gold_spans)]
            targets = [[span_matches[0] if len(span_matches) == 1 else 'O'
                        for span_matches in sent_matches]
                       for sent_matches in matches]
        # the predicted tag sequences must line up with the input sentences,
        # not with the extended sentences built below
        input_sentences = sentences
        sep_token = self.transformer.tokenizer.sep_token
        # append "<sep> span words" for every span, then a closing <sep>
        sentences = [sent + [t for bd in sent_bounds
                             for t in [sep_token] + sent[bd[0]:bd[1] + 1]]
                     + [sep_token]
                     for sent, sent_bounds in zip(sentences, bounds)]
        sep_ids = [[i for i, s in enumerate(sent) if s == sep_token]
                   for sent in sentences]
        embs = self._extract_sentence_vectors(sentences, pos, bounds)
        if kwargs.get('output_word_vecs', False):
            out_dict['word_vecs'] = embs
        # span vector = sum over the span in the sentence, concatenated with
        # the sum from the separator before its copy to the one after it
        span_vecs = [
            torch.stack([torch.cat((torch.sum(e[b[0]:b[1] + 1], dim=0),
                                    torch.sum(e[spi[i]:spi[i + 1] + 1], dim=0)))
                         for i, b in enumerate(bd)])
            if bd else torch.zeros((0)).to(self.device)
            for e, bd, spi in zip(embs, bounds, sep_ids)]
        ent_scores = [self.ent_output_layer(sv) if len(sv) else sv
                      for sv in span_vecs]
        if kwargs.get('output_ent_scores', False):
            out_dict['ent_scores'] = ent_scores
        out_dict['bounds'] = bounds
        if tags is None:
            out_dict['pred_tags'] = self.from_ent_scores(ent_scores, input_sentences, bounds)
            return out_dict
        ent_targs = torch.tensor([self.ent_tags.index(t)
                                  for targ in targets for t in targ],
                                 dtype=torch.long).to(self.device)
        ent_preds = torch.cat(ent_scores)
        # span_out is None when the span tags were passed in as pred_tags
        span_loss = (span_out['loss']
                     if self.config.tune_span_net and span_out is not None else None)
        if not len(ent_preds):
            # nothing to type; a tuned span_net still gets its own loss
            out_dict['loss'] = span_loss
            return out_dict
        ent_loss = cross_entropy(ent_preds, ent_targs)
        out_dict['loss'] = ent_loss if span_loss is None else ent_loss + span_loss
        return out_dict

    def from_ent_scores(self, ent_scores, sentences, bounds):
        """Turn per-span entity scores into BIO tag sequences.

        Args:
            ent_scores: per sentence, a ``(num_spans, len(ent_tags))`` score
                tensor, or an empty tensor when the sentence has no span.
            sentences: the input sentences; only their lengths are used.
            bounds: per sentence, the inclusive ``(start, end)`` word positions
                of the spans, in the same order as the scores.

        Returns:
            one tag list per sentence: ``'B-<type>'``/``'I-<type>'`` over each
            span whose arg-max type is not ``'O'``, ``'O'`` everywhere else.
        """
        max_tags = [[self.ent_tags[torch.argmax(e)] for e in es]
                    for es in ent_scores]
        combined_sequences = []
        for sent_tags, sent_bounds, sent in zip(max_tags, bounds, sentences):
            seq = ['O'] * len(sent)
            for tag, span in zip(sent_tags, sent_bounds):
                seq[span[0]] = 'O' if tag == 'O' else f'B-{tag}'
                for i in range(span[0] + 1, span[1] + 1):
                    seq[i] = 'O' if tag == 'O' else f'I-{tag}'
            combined_sequences.append(seq)
        return combined_sequences