Commit 37375ea · 1 parent: da2e323
Upload representation.py

representation.py ADDED (+90 -0)
@@ -0,0 +1,90 @@
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class TransformerRepresentation(nn.Module):
    def __init__(self, model_name='bert-base-uncased',
                 transformer_kwargs=None, device=DEFAULT_DEVICE):
        super(TransformerRepresentation, self).__init__()
        # Dropout overrides are built here rather than as a mutable
        # default argument.
        if transformer_kwargs is None:
            transformer_kwargs = {'attention_probs_dropout_prob': 0.1,
                                  'hidden_dropout_prob': 0.1}
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name,
                                               output_hidden_states=True,
                                               **transformer_kwargs)
        self.embedding_dim = self.model.config.hidden_size
        self.device = device
        # Place the encoder on the target device so its weights match the
        # inputs, which forward() moves to self.device.
        self.model.to(device)

    @staticmethod
    def add_subword_maps(texts, encodings):
        # Attach to each encoding a map from word index to the (start, end)
        # span of its sub-word tokens; word_to_tokens() returns None for
        # words dropped by truncation.
        for encoding, t in zip(encodings, texts):
            encoding.subword_map = [encoding.word_to_tokens(i)
                                    for i, _ in enumerate(t)]

    @staticmethod
    def apply_token_pooling_strategy(outputs, encodings, strategy='first'):
        """
        Applies a token pooling strategy for pretokenized inputs based on
        a sub-word mapping of words to tokens.

        :param outputs: Output of applying a `TransformerRepresentation.model`
            to a pretokenized input.
        :param encodings: Encodings from applying a
            `TransformerRepresentation.tokenizer` to a pretokenized input.
        :param strategy: One of ['first', 'last', 'sum', 'average'].
            Defaults to 'first'.
        :return: A list with one tensor of word vectors per input sequence.
        """
        # Group each word's sub-token vectors; words truncated away during
        # tokenization have no token span and are skipped.
        vec_map = [[vecs[m[0]:m[1]] for m in encoding.subword_map
                    if m is not None]
                   for vecs, encoding
                   in zip(outputs.last_hidden_state.unbind(), encodings)]
        if strategy == 'first':
            return [torch.stack([vec[0] for vec in vm])
                    if vm else torch.zeros(0) for vm in vec_map]
        elif strategy == 'last':
            return [torch.stack([vec[-1] for vec in vm])
                    if vm else torch.zeros(0) for vm in vec_map]
        elif strategy == 'sum':
            return [torch.stack([torch.sum(vec, dim=0) for vec in vm])
                    if vm else torch.zeros(0) for vm in vec_map]
        elif strategy == 'average':
            return [torch.stack([torch.mean(vec, dim=0) for vec in vm])
                    if vm else torch.zeros(0) for vm in vec_map]
        # Unrecognized strategy: fall back to the raw per-word token groups.
        return vec_map

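    # Illustration of the pooling strategies (the exact WordPiece split is an
    # assumption about 'bert-base-uncased', not something checked here): a
    # word like 'Blackburn' typically tokenizes to ['black', '##burn'], so
    # its subword_map entry spans two token positions. 'first' keeps the
    # vector of 'black', 'last' that of '##burn', and 'sum'/'average'
    # combine both into a single word vector.
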
    def add_special_tokens(self, tokens):
        # Register additional special tokens and grow the embedding matrix
        # to match the enlarged vocabulary.
        self.tokenizer.add_special_tokens(
            {'additional_special_tokens':
             self.tokenizer.additional_special_tokens + tokens})
        self.model.resize_token_embeddings(len(self.tokenizer))

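    # Hypothetical usage: calling add_special_tokens(['-LRB-', '-RRB-']) on an
    # instance would keep the PTB bracket markers that appear in the demo
    # below from being split into word pieces.
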
    def forward(self, text, is_pretokenized=False, add_special_tokens=True,
                token_pooling='first'):
        inputs = self.tokenizer(text, padding='longest',
                                is_split_into_words=is_pretokenized,
                                add_special_tokens=add_special_tokens,
                                return_tensors='pt',
                                max_length=512,
                                truncation=True).to(self.device)
        output = self.model(**inputs)
        if is_pretokenized:
            encodings = list(inputs.encodings)
            self.add_subword_maps(text, encodings)
            output.pooled_tokens = self.apply_token_pooling_strategy(
                output, encodings, strategy=token_pooling)
        return output


if __name__ == '__main__':
    # Sample pretokenized sentence with relation-extraction style
    # annotations: entity 1 ('Tom Thabane', PERSON) spans words 0-1 and
    # entity 2 ('All Basotho Convention', ORGANIZATION) spans words 10-12.
    toks = ['Tom', 'Thabane', 'resigned', 'in', 'October', 'last', 'year',
            'to', 'form', 'the', 'All', 'Basotho', 'Convention', '-LRB-',
            'ABC', '-RRB-', ',', 'crossing', 'the', 'floor', 'with', '17',
            'members', 'of', 'parliament', ',', 'causing', 'constitutional',
            'monarch', 'King', 'Letsie', 'III', 'to', 'dissolve',
            'parliament', 'and', 'call', 'the', 'snap', 'election', '.']
    e1_type = 'PERSON'
    e2_type = 'ORGANIZATION'
    e1_tokens = [0, 1]
    e2_tokens = [10, 12]
    # CoNLL-2003 style pretokenized sentences used as a demo batch.
    text = [['EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British', 'lamb', '.'],
            ['Peter', 'Blackburn'],
            ['BRUSSELS', '1996-08-22'],
            ['The', 'European', 'Commission', 'said', 'on', 'Thursday', 'it', 'disagreed', 'with', 'German', 'advice', 'to', 'consumers', 'to', 'shun', 'British', 'lamb', 'until', 'scientists', 'determine', 'whether', 'mad', 'cow', 'disease', 'can', 'be', 'transmitted', 'to', 'sheep', '.'],
            ['Germany', "'s", 'representative', 'to', 'the', 'European', 'Union', "'s", 'veterinary', 'committee', 'Werner', 'Zwingmann', 'said', 'on', 'Wednesday', 'consumers', 'should', 'buy', 'sheepmeat', 'from', 'countries', 'other', 'than', 'Britain', 'until', 'the', 'scientific', 'advice', 'was', 'clearer', '.'],
            ['"', 'We', 'do', "n't", 'support', 'any', 'such', 'recommendation', 'because', 'we', 'do', "n't", 'see', 'any', 'grounds', 'for', 'it', ',', '"', 'the', 'Commission', "'s", 'chief', 'spokesman', 'Nikolaus', 'van', 'der', 'Pas', 'told', 'a', 'news', 'briefing', '.'],
            ['He', 'said', 'further', 'scientific', 'study', 'was', 'required', 'and', 'if', 'it', 'was', 'found', 'that', 'action', 'was', 'needed', 'it', 'should', 'be', 'taken', 'by', 'the', 'European', 'Union', '.'],
            ['He', 'said', 'a', 'proposal', 'last', 'month', 'by', 'EU', 'Farm', 'Commissioner', 'Franz', 'Fischler', 'to', 'ban', 'sheep', 'brains', ',', 'spleens', 'and', 'spinal', 'cords', 'from', 'the', 'human', 'and', 'animal', 'food', 'chains', 'was', 'a', 'highly', 'specific', 'and', 'precautionary', 'move', 'to', 'protect', 'human', 'health', '.']]
    model = TransformerRepresentation()
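    # Minimal usage sketch (illustrative, not part of the uploaded file):
    # run the pretokenized demo batch through the encoder and inspect the
    # per-word vectors produced by the default 'first' pooling strategy.
    output = model(text, is_pretokenized=True)
    for sentence, vecs in zip(text, output.pooled_tokens):
        # One vector of size model.embedding_dim (768 for bert-base-uncased)
        # per word that survived truncation.
        print(len(sentence), tuple(vecs.shape))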