import torch
from torch import nn
from transformers import AutoModel
class Encoder(nn.Module):
    """Wraps a pretrained transformer and pools its output into a single vector."""

    def __init__(self, model_checkpoint, representation='cls', fixed=False):
        super(Encoder, self).__init__()
        self.encoder = AutoModel.from_pretrained(model_checkpoint)
        self.representation = representation  # 'cls' or 'mean' pooling
        self.fixed = fixed                    # if True, run the encoder without gradients

    def get_representation(self, input_ids, attention_mask, token_type_ids=None):
        output = None
        if input_ids is not None:
            # Pass inputs by keyword so they are not mis-bound positionally
            # across the different AutoModel architectures.
            if self.fixed:
                with torch.no_grad():
                    outputs = self.encoder(input_ids=input_ids,
                                           attention_mask=attention_mask,
                                           token_type_ids=token_type_ids)
            else:
                outputs = self.encoder(input_ids=input_ids,
                                       attention_mask=attention_mask,
                                       token_type_ids=token_type_ids)
            sequence_output = outputs['last_hidden_state']
            # sequence_output = sequence_output.masked_fill(~attention_mask[..., None].bool(), 0.0)
            if self.representation == 'cls':
                # Use the hidden state of the first ([CLS]) token.
                output = sequence_output[:, 0, :]
            elif self.representation == 'mean':
                # Mean-pool over non-padding tokens, then L2-normalize.
                s = torch.sum(sequence_output * attention_mask.unsqueeze(-1).float(), dim=1)
                d = attention_mask.sum(dim=1, keepdim=True).float()
                output = s / d
                output = torch.nn.functional.normalize(output, dim=-1)
                # output = sequence_output.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
            # elif self.representation == -100:
            #     output = outputs[1]
        return output
    def save(self, output_dir: str):
        # Copy the weights to CPU before writing them to disk.
        state_dict = self.encoder.state_dict()
        state_dict = type(state_dict)(
            {k: v.clone().cpu() for k, v in state_dict.items()})
        self.encoder.save_pretrained(output_dir, state_dict=state_dict)
class SharedBiEncoder(nn.Module):
    """Bi-encoder in which the question and the context share a single Encoder."""

    def __init__(self, model_checkpoint, encoder=None, representation='cls', fixed=False):
        super(SharedBiEncoder, self).__init__()
        if encoder is None:
            encoder = Encoder(model_checkpoint, representation, fixed)
        self.encoder = encoder
    def forward(self, q_ids, q_attn_mask, ctx_ids, ctx_attn_mask):
        q_out = self.encoder.get_representation(q_ids, q_attn_mask)
        ctx_out = self.encoder.get_representation(ctx_ids, ctx_attn_mask)
        return q_out, ctx_out

    def get_model(self):
        return self.encoder
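

# --- Usage sketch (illustrative only) ---
# A minimal example of how SharedBiEncoder might be driven end to end. The
# checkpoint name, tokenizer, and example strings below are assumptions for
# illustration, not part of this module.
if __name__ == '__main__':
    from transformers import AutoTokenizer

    checkpoint = 'bert-base-uncased'  # hypothetical checkpoint
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = SharedBiEncoder(checkpoint, representation='mean')

    q = tokenizer(['who wrote hamlet?'], padding=True, return_tensors='pt')
    ctx = tokenizer(['Hamlet is a tragedy written by William Shakespeare.'],
                    padding=True, return_tensors='pt')

    q_vec, ctx_vec = model(q['input_ids'], q['attention_mask'],
                           ctx['input_ids'], ctx['attention_mask'])
    # With 'mean' pooling the vectors are L2-normalized, so the dot product
    # is a cosine similarity between question and context.
    print((q_vec * ctx_vec).sum(dim=-1))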