# Spaces:
# Runtime error
# Runtime error
#!/usr/bin/env python3 | |
# coding=utf-8 | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from model.module.biaffine import Biaffine | |
class EdgeClassifier(nn.Module):
    """Score edge presence and/or edge labels for every ordered pair of positions.

    Each enabled head is an ``EdgeBiaffine`` scorer applied to the input ``x``.

    Args:
        dataset: must provide ``edge_presence_freq`` (read when ``presence`` and
            ``initialize``), ``edge_label_freqs`` (read when ``label`` and
            ``initialize``) and ``edge_label_field.vocab`` (read when ``label``).
        args: must provide ``hidden_size`` plus ``hidden_size_edge_presence`` /
            ``dropout_edge_presence`` and/or ``hidden_size_edge_label`` /
            ``dropout_edge_label`` for the enabled heads.
        initialize: if True, each head's ``bias_init`` is the logit of the
            corresponding dataset frequency; otherwise ``None`` is passed.
        presence: build the edge-presence head (one score per pair).
        label: build the edge-label head (one score per label in the vocab).
    """

    # Frequencies are clamped to [eps, 1 - eps] before taking the logit so that
    # a frequency of exactly 0 or 1 yields a large finite bias instead of +/-inf
    # (an infinite initial bias would propagate inf/NaN into training).
    _FREQ_EPS = 1e-6

    def __init__(self, dataset, args, initialize: bool, presence: bool, label: bool):
        super().__init__()

        self.presence = presence
        if self.presence:
            presence_init = (
                self._logit(torch.tensor([dataset.edge_presence_freq])) if initialize else None
            )
            self.edge_presence = EdgeBiaffine(
                args.hidden_size, args.hidden_size_edge_presence, 1, args.dropout_edge_presence, bias_init=presence_init
            )

        self.label = label
        if self.label:
            label_init = self._logit(dataset.edge_label_freqs) if initialize else None
            n_labels = len(dataset.edge_label_field.vocab)
            self.edge_label = EdgeBiaffine(
                args.hidden_size, args.hidden_size_edge_label, n_labels, args.dropout_edge_label, bias_init=label_init
            )

    @staticmethod
    def _logit(freq, eps: float = _FREQ_EPS) -> torch.Tensor:
        """Return ``log(p / (1 - p))`` elementwise, with ``p = freq`` clamped to ``[eps, 1 - eps]``.

        ``freq`` may be a tensor or anything ``torch.as_tensor`` accepts.
        Non-floating input is converted to float, matching the true division
        the unclamped formula would have performed.
        """
        p = torch.as_tensor(freq)
        if not p.is_floating_point():
            p = p.float()
        p = p.clamp(eps, 1.0 - eps)
        return (p / (1.0 - p)).log()

    def forward(self, x):
        """Return ``(presence, label)`` scores; a disabled head yields ``None``."""
        presence, label = None, None

        if self.presence:
            # Drop the size-1 output dimension of the presence head.
            presence = self.edge_presence(x).squeeze(-1)  # shape: (B, T, T)
        if self.label:
            label = self.edge_label(x)  # shape: (B, T, T, O_1)

        return presence, label
class EdgeBiaffine(nn.Module):
    """Biaffine scorer over all ordered pairs of positions.

    A linear layer, followed by ELU and dropout, projects each position to
    ``2 * bottleneck_dim`` features. These are split into a "predecessors" half
    and a "current" half, and ``Biaffine`` scores them against each other.

    Args:
        hidden_dim: size of the last dimension of the input ``x``.
        bottleneck_dim: size of each of the two projected halves.
        output_dim: number of scores per pair, passed to ``Biaffine``.
        dropout: dropout probability applied after the activation.
        bias_init: optional initial bias, forwarded unchanged to ``Biaffine``.
    """

    def __init__(self, hidden_dim, bottleneck_dim, output_dim, dropout, bias_init=None):
        super().__init__()
        self.hidden = nn.Linear(hidden_dim, 2 * bottleneck_dim)
        self.output = Biaffine(bottleneck_dim, output_dim, bias_init=bias_init)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return the pairwise scores produced by ``Biaffine(current, predecessors)``."""
        x = self.dropout(F.elu(self.hidden(x)))  # shape: (B, T, 2H)
        # First half of the features plays the "predecessors" role, second half "current".
        predecessors, current = x.chunk(2, dim=-1)  # shape: (B, T, H), (B, T, H)
        edge = self.output(current, predecessors)  # shape: (B, T, T, O)
        return edge