|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.nn.utils.rnn import pad_sequence |
|
from transformers import AutoModel |
|
from pathlib import Path |
|
|
|
|
|
class LinearTokenSelector(nn.Module):
    """Token-level extractive summarizer: an encoder followed by a per-token
    binary (keep/drop) linear classifier over the last hidden states."""

    def __init__(self, encoder, embedding_size=768):
        """
        Args:
            encoder: HF-style transformer model returning ``hidden_states``
                when called with ``output_hidden_states=True``. May be ``None``
                when the encoder is attached after construction (see
                ``load_checkpoint``).
            embedding_size: width of the encoder's hidden states; must match
                the encoder actually used — TODO confirm against config.
        """
        super().__init__()
        self.encoder = encoder
        # Two logits per token: index 0 = drop, index 1 = keep.
        self.classifier = nn.Linear(embedding_size, 2, bias=False)

    def forward(self, x):
        """Return per-token log-probabilities with shape (batch, seq, 2)."""
        output = self.encoder(x, output_hidden_states=True)
        hidden = output["hidden_states"][-1]  # last layer: (batch, seq, emb)
        logits = self.classifier(hidden)
        return F.log_softmax(logits, dim=2)

    def save(self, classifier_path, encoder_path):
        """Persist classifier head and encoder separately.

        Only the ``classifier.*`` entries of the state dict are written to
        ``classifier_path``; the encoder goes through ``save_pretrained`` so
        it can be restored with ``AutoModel.from_pretrained``.
        """
        state = {
            k: v for k, v in self.state_dict().items()
            if k.startswith("classifier")
        }
        torch.save(state, classifier_path)
        self.encoder.save_pretrained(encoder_path)

    def predict(self, texts, tokenizer, device):
        """Tokenize *texts*, keep tokens classified as 1, decode summaries.

        Args:
            texts: list of raw input strings.
            tokenizer: HF-style tokenizer (callable, provides ``decode``).
            device: device the padded id batch is moved to.

        Returns:
            list[str]: one decoded summary per input text.
        """
        input_ids = tokenizer(texts)["input_ids"]
        batch = pad_sequence(
            [torch.tensor(ids) for ids in input_ids], batch_first=True
        ).to(device)
        # Inference only: skip autograd bookkeeping (fixes needless
        # graph construction in the original).
        with torch.no_grad():
            logits = self.forward(batch)
        argmax_labels = torch.argmax(logits, dim=2)
        return labels_to_summary(batch, argmax_labels, tokenizer)
|
|
|
|
|
def load_model(model_dir, device="cuda", prefix="best"):
    """Load the checkpoint under ``model_dir/checkpoints`` whose directory
    name starts with *prefix* (last match in directory order wins).

    Args:
        model_dir: model directory (str or Path) containing ``checkpoints/``.
        device: device passed through to ``load_checkpoint``.
        prefix: checkpoint-name prefix to select, e.g. ``"best"``.

    Returns:
        The model returned by ``load_checkpoint``.

    Raises:
        FileNotFoundError: if no checkpoint directory matches *prefix*.
            (Previously an unmatched prefix crashed with an unhelpful
            ``NameError`` on the unbound ``checkpoint_dir``.)
    """
    if isinstance(model_dir, str):
        model_dir = Path(model_dir)
    checkpoints = model_dir / "checkpoints"
    checkpoint_dir = None
    for p in checkpoints.iterdir():
        if p.name.startswith(prefix):
            checkpoint_dir = p
    if checkpoint_dir is None:
        raise FileNotFoundError(
            f"no checkpoint starting with {prefix!r} in {checkpoints}"
        )
    return load_checkpoint(checkpoint_dir, device=device)
|
|
|
|
|
def load_checkpoint(checkpoint_dir, device="cuda"):
    """Restore a LinearTokenSelector from a checkpoint directory.

    Expects ``encoder.bin`` (an ``AutoModel``-loadable encoder) and
    ``classifier.bin`` (a state dict whose ``classifier.*`` entries hold
    the linear head) inside *checkpoint_dir*.
    """
    if isinstance(checkpoint_dir, str):
        checkpoint_dir = Path(checkpoint_dir)

    encoder_file = checkpoint_dir / "encoder.bin"
    head_file = checkpoint_dir / "classifier.bin"

    # The encoder's word-embedding width determines the classifier's
    # input size, so load the encoder first.
    encoder = AutoModel.from_pretrained(encoder_file).to(device)
    hidden_width = encoder.state_dict()["embeddings.word_embeddings.weight"].shape[1]

    model = LinearTokenSelector(None, hidden_width).to(device)
    saved_state = torch.load(head_file, map_location=device)
    # Keep only the head's parameters; anything else in the file is ignored.
    head_state = {k: v for k, v in saved_state.items() if k.startswith("classifier")}
    model.load_state_dict(head_state)
    model.encoder = encoder
    return model.to(device)
|
|
|
|
|
def labels_to_summary(input_batch, label_batch, tokenizer):
    """Turn per-token keep/drop labels into decoded summary strings.

    Args:
        input_batch: iterable of token-id sequences (ints or 0-d tensors).
        label_batch: parallel iterable of per-token labels; 1 means keep.
        tokenizer: HF-style tokenizer providing ``decode``.

    Returns:
        list[str]: one decoded summary per input sequence.
    """
    summaries = []
    for input_ids, labels in zip(input_batch, label_batch):
        # zip pairs tokens with labels directly, replacing the fragile
        # range(len(...)) cross-indexing of the original.
        kept = [int(tok) for tok, keep in zip(input_ids, labels) if keep == 1]
        summaries.append(tokenizer.decode(kept, skip_special_tokens=True))
    return summaries
|
|