Spaces:
Build error
Build error
import copy | |
import torch | |
import torch.nn as nn | |
from torch.nn.utils.rnn import pad_sequence | |
from torch.nn.functional import cross_entropy, binary_cross_entropy | |
from tqdm.auto import tqdm | |
from utils import Config, extract_spans, generate_targets | |
from representation import TransformerRepresentation | |
from layers import SpanEnumerationLayer | |
DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class SpanNet(nn.Module):
    """Span detector on top of a transformer encoder.

    Two representation types are supported (``repr_type``):

    * ``'token_classification'``: every word is classified as one of the
      span tags ``B``/``I``/``O``.
    * ``'span_enumeration'``: candidate spans produced by
      ``SpanEnumerationLayer`` each receive one score, and the best
      non-overlapping spans are decoded into ``B``/``I``/``O`` tags.

    Keyword args:
        pos: optional collection of POS tags; each one gets a ``[TAG]``
            special token that is prefixed to the word it annotates.
        dp: dropout probability of the output head (default 0.3).
        transformer_model_name: encoder name (default 'bert-base-uncased').
        token_pooling: pooling mode passed to ``TransformerRepresentation``
            (default 'sum').
        device: torch device of the model (default ``DEFAULT_DEVICE``).
        repr_type: ``'token_classification'`` (default) or
            ``'span_enumeration'``.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.config = Config()
        self.config.pos = kwargs.get('pos', None)
        self.config.dp = kwargs.get('dp', 0.3)
        self.config.transformer_model_name = kwargs.get('transformer_model_name', 'bert-base-uncased')
        self.config.token_pooling = kwargs.get('token_pooling', 'sum')
        self.device = kwargs.get('device', DEFAULT_DEVICE)
        self.config.repr_type = kwargs.get('repr_type', 'token_classification')
        assert self.config.repr_type in ['token_classification',
                                         'span_enumeration'], 'Invalid representaton type'
        self.transformer = TransformerRepresentation(
            model_name=self.config.transformer_model_name,
            device=self.device).to(self.device)
        self.transformer_dim = self.transformer.embedding_dim
        if self.config.pos:
            self.transformer.add_special_tokens([f'[{p}]' for p in self.config.pos])
        self.span_tags = ['B', 'I', 'O']
        self.enumeration_layer = SpanEnumerationLayer()
        output_size = {'token_classification': len(self.span_tags),
                       'span_enumeration': 1}
        # Put the head on the same device as the transformer and the targets
        # built in ``forward`` so the model works without an extra .to() call.
        self.span_output_layer = nn.Sequential(
            nn.Linear(self.transformer_dim, self.transformer_dim),
            nn.ReLU(), nn.Dropout(p=self.config.dp),
            nn.Linear(self.transformer_dim, output_size[self.config.repr_type])).to(self.device)

    def to_dict(self):
        """Return the config attributes and the state dict, as stored by ``save_model``."""
        return {
            'model_config': self.config.__dict__,
            'model_state_dict': self.state_dict()
        }

    @classmethod
    def load_model(cls, model_path, device=DEFAULT_DEVICE):
        """Rebuild a model written by ``save_model`` and switch it to eval mode.

        The checkpoint is loaded with ``map_location=device``, and the model is
        constructed on ``device`` too, so both end up on the same device.
        State-dict keys that do not match are ignored (``strict=False``).
        """
        res = torch.load(model_path, device)
        model = cls(**{**res['model_config'], 'device': device})
        model.load_state_dict(res['model_state_dict'], strict=False)
        model.eval()
        return model

    def preds_to_sequences(self, predictions, enumerations, length):
        """Decode the span scores of ONE sentence into a ``B``/``I``/``O`` list.

        Args:
            predictions: one score per candidate span; each item must support
                ``.item()`` (e.g. a one-element tensor).
            enumerations: candidate spans as inclusive ``(start, end)`` word
                indices, aligned with ``predictions``.
            length: number of words in the sentence (length of the result).

        Spans are visited from the highest to the lowest score (ties keep
        their input order). A span is kept only if it does not overlap any
        span kept before it. A kept span is tagged ``B`` on its first word and
        ``I`` on the remaining words; every other word stays ``O``. No score
        threshold is applied.
        """
        # Sort indices rather than keying a dict by score, so that spans with
        # equal scores are not lost.
        order = sorted(range(len(enumerations)),
                       key=lambda idx: predictions[idx].item(), reverse=True)
        kept = []
        for idx in order:
            start, end = enumerations[idx]
            # Inclusive spans are disjoint iff one ends before the other starts;
            # BIO tags cannot express nested or crossing spans.
            if all(end < k_start or k_end < start for k_start, k_end in kept):
                kept.append((start, end))
        tagged_seq = ['O'] * length
        for start, end in kept:
            tagged_seq[start] = 'B'
            tagged_seq[start + 1:end + 1] = ['I'] * (end - start)
        return tagged_seq

    def save_model(self, output_path):
        """Write ``to_dict()`` to ``output_path`` with ``torch.save``."""
        torch.save(self.to_dict(), output_path)

    def _extract_sentence_vectors(self, sentences, pos=None):
        """Encode pre-tokenized sentences and return the transformer's ``pooled_tokens``.

        When ``pos`` is given and the model was configured with POS tags, each
        word is first prefixed with its ``[TAG]`` token. The result presumably
        holds one vector per (non-truncated) word; confirm against
        ``TransformerRepresentation``.
        """
        if pos and self.config.pos:
            sentences = [[f'[{tag}] {word}' for word, tag in zip(sent, sent_pos)]
                         for sent, sent_pos in zip(sentences, pos)]
        outs = self.transformer(sentences, is_pretokenized=True,
                                token_pooling=self.config.token_pooling)
        return outs.pooled_tokens

    def forward(self, sentences, pos=None, tags=None, **kwargs):
        """Predict span tags and, when gold ``tags`` are given, compute the loss.

        Args:
            sentences: batch of pre-tokenized sentences (lists of words).
            pos: optional per-word POS tags (used only if configured with ``pos``).
            tags: optional gold tag sequences. For token classification the
                first character of each tag (``B``/``I``/``O``) is the target.
            output_word_vecs: also return the word vectors as ``'word_vecs'``.
            output_span_scores: also return the raw scores as ``'span_scores'``.

        Returns:
            dict with ``'pred_tags'`` (one ``B``/``I``/``O`` list per sentence),
            the optional outputs above, and ``'loss'`` when ``tags`` is given.
        """
        out_dict = {}
        embs = self._extract_sentence_vectors(sentences, pos)
        if kwargs.get('output_word_vecs', False):
            out_dict['word_vecs'] = embs
        lens = [len(s) for s in embs]
        enumerations = None
        if self.config.repr_type == 'span_enumeration':
            embs, enumerations = self.enumeration_layer(embs, lens)
            lens = [len(e) for e in enumerations]
        input_layer = pad_sequence(embs, batch_first=True)
        # Drop the padding rows again: one score vector per word / per span.
        span_scores = [torch.unbind(f)[:l]
                       for f, l in zip(self.span_output_layer(input_layer), lens)]
        if kwargs.get('output_span_scores', False):
            out_dict['span_scores'] = span_scores
        if self.config.repr_type == 'token_classification':
            out_dict['pred_tags'] = self.from_span_scores(span_scores)
        else:
            # span_scores is already grouped per sentence, so each sentence's
            # scores pair directly with its enumeration.
            out_dict['pred_tags'] = [
                self.preds_to_sequences(scores, enum, len(sent))
                for scores, enum, sent in zip(span_scores, enumerations, sentences)]
        if tags is None:
            return out_dict
        if self.config.repr_type == 'span_enumeration':
            # generate_targets presumably yields one 0/1 target per enumerated
            # span per sentence - TODO confirm against utils.generate_targets.
            targets = generate_targets(enumerations, tags)
            targets = torch.Tensor([t for st in targets for t in st]).to(self.device)
            # Concatenate the score tensors themselves so gradients reach the
            # model; rebuilding them via torch.Tensor(...) detaches them.
            flat_scores = torch.cat([s.reshape(-1) for sc in span_scores for s in sc])
            span_loss = nn.functional.binary_cross_entropy_with_logits(flat_scores, targets)
        else:
            # limit the targets of each sentence to the words not truncated during tokenization
            targets = torch.cat(
                [torch.tensor([self.span_tags.index(t[0]) for t, _ in zip(tg, sc)],
                              dtype=torch.long)
                 for tg, sc in zip(tags, span_scores)]).to(self.device)
            flat_scores = torch.stack([s for tg, sc in zip(tags, span_scores)
                                       for _, s in zip(tg, sc)])
            span_loss = cross_entropy(flat_scores, targets)
        out_dict['loss'] = span_loss
        return out_dict

    def from_span_scores(self, span_scores):
        """Map per-word score vectors to span tags via argmax over ``self.span_tags``."""
        pred_span_ids = [[torch.argmax(s) for s in sc] for sc in span_scores]
        return [[self.span_tags[idx] for idx in sequence]
                for sequence in pred_span_ids]
class EntNet(nn.Module):
    """Entity typer that labels the spans proposed by a ``SpanNet``.

    For every sentence the predicted spans (from ``span_net``, or from the
    ``pred_tags`` keyword of ``forward``) are appended to the sentence as
    ``<sep> span words`` segments, followed by a closing ``<sep>``. Each span
    is represented by the summed vectors of its words, concatenated with the
    summed vectors of its appended segment (both separators included). This
    representation is classified into one of ``ent_tags``.

    Keyword args:
        span_net: ``SpanNet`` used to propose spans; it is only called when
            ``pred_tags`` is not passed to ``forward``.
        tune_span_net: train ``span_net`` jointly (default False). Otherwise
            its parameters are frozen.
        use_span_emb: request word vectors from ``span_net`` (default False).
            NOTE(review): the returned vectors are not used by this class.
        use_ent_markers: wrap each predicted span in ``[ENT]``/``[/ENT]``
            (default False).
        ent_tags: list of entity labels. For training it must contain ``'O'``,
            the target of predicted spans that match no gold span.
        pos, dp, transformer_model_name, token_pooling (default 'first'),
        device: as for ``SpanNet``.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.config = Config()
        self.span_net = kwargs.get('span_net')
        self.config.tune_span_net = kwargs.get('tune_span_net', False)
        self.config.use_span_emb = kwargs.get('use_span_emb', False)
        self.config.use_ent_markers = kwargs.get('use_ent_markers', False)
        # it is possible to tune span_net without using its embeddings
        if self.span_net and not self.config.tune_span_net:
            for p in self.span_net.parameters():
                p.requires_grad = False
        self.config.ent_tags = self.ent_tags = kwargs.get('ent_tags')
        self.config.pos = kwargs.get('pos', None)
        self.config.dp = kwargs.get('dp', 0.3)
        self.config.transformer_model_name = kwargs.get('transformer_model_name', 'bert-base-uncased')
        self.config.token_pooling = kwargs.get('token_pooling', 'first')
        self.device = kwargs.get('device', DEFAULT_DEVICE)
        self.transformer = TransformerRepresentation(
            model_name=self.config.transformer_model_name,
            device=self.device).to(self.device)
        self.transformer_dim = self.transformer.embedding_dim
        self.transformer.add_special_tokens(['[ENT]', '[/ENT]'])
        self.transformer.add_special_tokens(['[INFO]', '[/INFO]'])
        if self.config.pos:
            self.transformer.add_special_tokens(
                ['[' + p + ']' for p in self.config.pos])
        # Input is [span words summed ; span segment summed], hence 2 * dim.
        # Keep the head on self.device, like the transformer.
        self.ent_output_layer = nn.Sequential(
            nn.Linear(2 * self.transformer_dim, 2 * self.transformer_dim),
            nn.ReLU(), nn.Dropout(p=self.config.dp),
            nn.Linear(2 * self.transformer_dim, len(self.config.ent_tags))).to(self.device)

    def to_dict(self):
        """Return this model's config, the span net's config (or None) and the state dict."""
        return {
            'model_config': self.config.__dict__,
            'span_net_config': self.span_net.config.__dict__ if self.span_net is not None else None,
            'model_state_dict': self.state_dict()
        }

    @classmethod
    def load_model(cls, model_path, device=DEFAULT_DEVICE):
        """Rebuild a model (and its ``SpanNet``, if any) saved by ``save_model``.

        The checkpoint is loaded with ``map_location=device``, and both networks
        are constructed on ``device``. The model is returned in eval mode.
        """
        res = torch.load(model_path, device)
        span_net_config = res['span_net_config']
        span_net = (SpanNet(**{**span_net_config, 'device': device})
                    if span_net_config is not None else None)
        model = cls(**{**res['model_config'], 'span_net': span_net, 'device': device})
        model.load_state_dict(res['model_state_dict'])
        model.eval()
        return model

    def save_model(self, output_path):
        """Write ``to_dict()`` to ``output_path`` with ``torch.save``."""
        torch.save(self.to_dict(), output_path)

    def _append_span_segments(self, sentences, bounds):
        """Append ``<sep> span words`` per span plus a closing ``<sep>`` to each sentence.

        Returns the extended sentences (new lists; the inputs are untouched)
        and, per sentence, the positions of every appended separator.
        Positions are recorded while building, so a word that happens to equal
        the separator string is never mistaken for one.
        """
        sep = self.transformer.tokenizer.sep_token
        extended, sep_ids = [], []
        for sent, sent_bounds in zip(sentences, bounds):
            tokens = list(sent)
            positions = []
            for bd in sent_bounds:
                positions.append(len(tokens))
                tokens.append(sep)
                tokens.extend(sent[bd[0]:bd[1] + 1])
            positions.append(len(tokens))
            tokens.append(sep)
            extended.append(tokens)
            sep_ids.append(positions)
        return extended, sep_ids

    def _extract_sentence_vectors(self, sentences, pos=None, ent_bounds=None):
        """Encode pre-tokenized sentences and return the transformer's ``pooled_tokens``.

        With ``pos`` (and a POS-configured model), every word that has a POS
        tag is prefixed with ``[TAG]``. Words beyond the end of the tag
        sequence, such as the appended span segments, are kept unchanged.
        With ``ent_bounds`` and ``use_ent_markers``, the first and last word of
        each span get ``[ENT]``/``[/ENT]``. The input lists are not modified.
        """
        if pos and self.config.pos:
            # zip() stops at the shorter sequence, so re-attach the words that
            # have no POS tag instead of silently dropping them.
            sentences = [[f'[{tag}] {word}' for word, tag in zip(sent, sent_pos)]
                         + list(sent[len(sent_pos):])
                         for sent, sent_pos in zip(sentences, pos)]
        else:
            # copy, because the markers below edit the words in place
            sentences = [list(sent) for sent in sentences]
        if ent_bounds and self.config.use_ent_markers:
            for sent, sent_ents in zip(sentences, ent_bounds):
                for ent in sent_ents:
                    sent[ent[0]] = f'[ENT] {sent[ent[0]]}'
                    sent[ent[1]] = f'{sent[ent[1]]} [/ENT]'
        outs = self.transformer(sentences, is_pretokenized=True,
                                token_pooling=self.config.token_pooling)
        return outs.pooled_tokens

    def forward(self, sentences, pos=None, tags=None, **kwargs):
        """Classify predicted spans and, when gold ``tags`` are given, compute the loss.

        Args:
            sentences: batch of pre-tokenized sentences (lists of words).
            pos: optional per-word POS tags.
            tags: optional gold tag sequences, as accepted by ``extract_spans``.
            pred_tags: precomputed span tag sequences; skips calling ``span_net``.
            output_word_vecs: also return the encoder vectors as ``'word_vecs'``.
            output_ent_scores: also return the span scores as ``'ent_scores'``.

        Returns:
            dict with ``'bounds'`` (predicted spans per sentence) and the
            optional outputs above. Without ``tags`` it also holds
            ``'pred_tags'``: one ``B-<tag>``/``I-<tag>``/``O`` list per
            sentence, of the sentence's length. With ``tags`` it holds
            ``'loss'``, which is None when no span was predicted in the batch.
        """
        out_dict = {}
        span_out = None
        pred_span_seqs = kwargs.get('pred_tags', None)
        if pred_span_seqs is None:
            span_out = self.span_net(sentences, pos=pos,
                                     output_word_vecs=self.config.use_span_emb,
                                     tags=tags if self.config.tune_span_net else None)
            pred_span_seqs = span_out['pred_tags']
        # one list of inclusive (start, end) word bounds per sentence
        bounds = [[e[1] for e in extract_spans(t, tagless=True)[3]]
                  for t in pred_span_seqs]
        extended, sep_ids = self._append_span_segments(sentences, bounds)
        embs = self._extract_sentence_vectors(extended, pos, bounds)
        if kwargs.get('output_word_vecs', False):
            out_dict['word_vecs'] = embs
        span_vecs = [
            torch.stack([torch.cat((torch.sum(e[b[0]:b[1] + 1], dim=0),
                                    torch.sum(e[spi[i]:spi[i + 1] + 1], dim=0)))
                         for i, b in enumerate(bd)])
            if bd else torch.zeros(0, device=self.device)
            for e, bd, spi in zip(embs, bounds, sep_ids)]
        ent_scores = [self.ent_output_layer(sv) if len(sv) else sv
                      for sv in span_vecs]
        if kwargs.get('output_ent_scores', False):
            out_dict['ent_scores'] = ent_scores
        out_dict['bounds'] = bounds
        if tags is None:
            # Rebuild tags over the ORIGINAL sentences, not the extended ones.
            out_dict['pred_tags'] = self.from_ent_scores(ent_scores, sentences, bounds)
            return out_dict
        gold_spans = [list(extract_spans(t, tagless=True)[3]) for t in tags]
        ent_targs = []
        for preds, golds in zip(bounds, gold_spans):
            for p in preds:
                matched = [g[0] for g in golds
                           if p[0] == g[1][0] and p[1] == g[1][1]]
                # only an unambiguous exact-boundary match yields an entity label
                label = matched[0] if len(matched) == 1 else 'O'
                ent_targs.append(self.ent_tags.index(label))
        ent_targs = torch.tensor(ent_targs, dtype=torch.long, device=self.device)
        scored = [s for s in ent_scores if len(s)]
        if not scored:
            out_dict['loss'] = None
            return out_dict
        ent_loss = cross_entropy(torch.cat(scored), ent_targs)
        # span_out only exists when span_net was actually run in this call
        if self.config.tune_span_net and span_out is not None:
            ent_loss = ent_loss + span_out['loss']
        out_dict['loss'] = ent_loss
        return out_dict

    def from_ent_scores(self, ent_scores, sentences, bounds):
        """Turn per-span entity scores into ``B-<tag>``/``I-<tag>``/``O`` sequences.

        Each span takes the argmax label over ``self.ent_tags``. A span labelled
        ``'O'`` is written as ``O``; words outside every span are ``O``. Each
        output list has the length of the matching entry of ``sentences``.
        """
        sequences = []
        for scores, sent, sent_bounds in zip(ent_scores, sentences, bounds):
            seq = ['O'] * len(sent)
            for score, b in zip(scores, sent_bounds):
                label = self.ent_tags[torch.argmax(score)]
                for i in range(b[0], b[1] + 1):
                    if label == 'O':
                        seq[i] = 'O'
                    else:
                        seq[i] = f'B-{label}' if i == b[0] else f'I-{label}'
            sequences.append(seq)
        return sequences