File size: 440 Bytes
8044721
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
#!/usr/bin/env python3
# coding=utf-8

import torch
from data.field.mini_torchtext.field import RawField


class BertField(RawField):
    """Field that turns a sequence of token ids into model-ready tensors.

    ``process`` returns the ids as a ``torch.long`` tensor together with an
    attention mask of ones that has the same shape.
    """

    def __init__(self):
        super().__init__()

    def process(self, example, device=None):
        """Convert token ids to a long tensor plus an all-ones attention mask.

        Args:
            example: Sequence of integer token ids (a tensor is also accepted).
                Presumably ids were produced by a BERT tokenizer upstream —
                TODO confirm against the caller.
            device: Optional torch device on which to create both tensors.
                ``None`` uses the default device.

        Returns:
            Tuple ``(input_ids, attention_mask)``: two ``torch.long`` tensors of
            identical shape, where every mask entry is 1.
        """
        # The legacy ``torch.LongTensor(..., device=...)`` constructor only
        # accepts CPU devices; ``as_tensor`` honours any requested device.
        input_ids = torch.as_tensor(example, dtype=torch.long, device=device)
        # No position is masked here; padding, if any, is presumably handled
        # later by the batching code — verify against the caller.
        attention_mask = torch.ones_like(input_ids)
        return input_ids, attention_mask