File size: 525 Bytes
8044721
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
#!/usr/bin/env python3
# coding=utf-8

import torch
from data.field.mini_torchtext.field import RawField


class AnchorField(RawField):
    """Field that expands an anchor specification into a 0/1 matrix and a row mask."""

    def process(self, batch, device=None):
        """Build the anchor matrix and its row mask for ``batch``.

        ``batch`` is handed to :meth:`pad` unchanged, together with ``device``.

        Returns:
            A ``(matrix, mask)`` tuple exactly as produced by :meth:`pad`.
        """
        matrix, empty_rows = self.pad(batch, device)
        return matrix, empty_rows

    def pad(self, anchors, device):
        """Materialise ``anchors`` as a dense ``torch.long`` matrix.

        Args:
            anchors: indexable whose first two items are the matrix
                dimensions and whose last item is an iterable of index
                pairs; only element 0 and element 1 of each pair are read.
                # NOTE(review): presumably (rows, cols, [(row, col), ...]);
                # confirm against the caller.
            device: device on which the matrix is allocated.

        Returns:
            ``(matrix, mask)`` where ``matrix`` holds 1 at every listed
            position and 0 elsewhere, and ``mask`` is a boolean tensor that
            is ``True`` for each row of ``matrix`` whose sum is zero.
        """
        n_rows, n_cols, positions = anchors[0], anchors[1], anchors[-1]
        matrix = torch.zeros(n_rows, n_cols, dtype=torch.long, device=device)

        for position in positions:
            row, col = position[0], position[1]
            matrix[row, col] = 1

        # A row that received no anchor sums to zero along the last dim.
        empty_rows = matrix.sum(dim=-1).eq(0)
        return matrix, empty_rows