|
|
|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
from model.head.abstract_head import AbstractHead |
|
from data.parser.to_mrp.labeled_edge_parser import LabeledEdgeParser |
|
from utility.cross_entropy import binary_cross_entropy |
|
from utility.hungarian_matching import match_label |
|
|
|
|
|
class LabeledEdgeHead(AbstractHead):
    """Head whose label classifier emits one logit per query, plus a learned top node.

    Attributes read here but not defined in this class (``label_classifier``,
    ``edge_classifier``, ``query_length``, ``focal``) are presumably provided
    by ``AbstractHead`` -- confirm against the base class.
    """

    def __init__(self, dataset, args, initialize):
        """Configure the base head, then add the top-node parameter and the parser."""
        # Switches handed to the base-class constructor as its ``config`` argument.
        enabled_parts = {
            "label": True,
            "edge presence": True,
            "edge label": True,
            "anchor": True,
            "source_anchor": False,
            "target_anchor": False
        }
        super().__init__(dataset, args, enabled_parts, initialize)

        # Trainable (1, 1, hidden_size) tensor; forward_edge prepends it to the
        # decoder output along dim 1.
        random_init = torch.randn(1, 1, args.hidden_size)
        self.top_node = nn.Parameter(random_init, requires_grad=True)
        self.parser = LabeledEdgeParser(dataset)

    def init_label_classifier(self, dataset, args, config, initialize: bool):
        """Build the label classifier: dropout followed by a one-output linear layer.

        When ``initialize`` is true, the linear layer's bias is overwritten with
        ``log(p / (1 - p))`` where ``p = dataset.label_freqs[1]``.
        ``config`` is accepted but not used by this method.
        """
        dropout = nn.Dropout(args.dropout_label)
        projection = nn.Linear(args.hidden_size, 1, bias=True)
        classifier = nn.Sequential(dropout, projection)

        if not initialize:
            return classifier

        # Log-odds of the frequency stored at index 1 of dataset.label_freqs.
        freq = torch.tensor([dataset.label_freqs[1]])
        odds = freq / (1.0 - freq)
        projection.bias.data = odds.log()
        return classifier

    def forward_label(self, decoder_output):
        """Apply the label classifier to ``decoder_output`` and return the result."""
        logits = self.label_classifier(decoder_output)
        return logits

    def forward_edge(self, decoder_output):
        """Prepend the top node to every batch item and run the edge classifier."""
        batch_size = decoder_output.size(0)
        # expand() broadcasts the single top node across the batch dimension.
        top = self.top_node.expand(batch_size, -1, -1)
        with_top = torch.cat([top, decoder_output], dim=1)
        return self.edge_classifier(with_top)

    def loss_label(self, prediction, target, mask, matching):
        """Return ``{"label": loss}`` with the binary cross-entropy of the label logits.

        The targets are obtained by passing ``target["labels"][0]`` and
        ``matching`` through ``match_label``; the trailing dimension of the
        logits is squeezed away before the loss is computed.
        """
        logits = prediction["label"]
        matched = match_label(
            target["labels"][0],
            matching,
            logits.shape[:-1],
            logits.device,
            self.query_length,
        )
        loss = binary_cross_entropy(logits.squeeze(-1), matched.float(), mask, focal=self.focal)
        return {"label": loss}

    def inference_label(self, prediction):
        """Threshold the logits at zero and return the decisions as a long tensor."""
        is_positive = prediction.squeeze(-1) > 0.0
        return is_positive.long()

    def label_cost_matrix(self, output, batch, decoder_lens, b: int):
        """Compute the label cost matrix for batch item ``b``.

        Returns the float ``1.0`` when ``output["label"]`` is ``None``.
        Otherwise returns the transposed square root of the product of
        (a) the target-weighted sum over the two-way probabilities and
        (b) the positive-class probability.
        """
        logits = output["label"]
        if logits is None:
            return 1.0

        gold = batch["anchored_labels"][b]

        # Positive-class probability for the first decoder_lens[b] queries,
        # with a leading singleton dimension added.
        positive = logits[b, : decoder_lens[b], :].sigmoid().unsqueeze(0)
        # Stack [1 - p, p] along the last dimension.
        both = torch.cat([1.0 - positive, positive], dim=-1)
        # Repeat each target entry query_length times along dim 1
        # (presumably to line targets up with the queries -- confirm).
        gold_per_query = gold.repeat_interleave(self.query_length, dim=1)

        agreement = (gold_per_query * both).sum(-1)
        positive_mass = both[:, :, 1:].sum(-1)
        return (agreement * positive_mass).t().sqrt()
|
|