import json
import os

import attr
import torch

from seq2struct.models import abstract_preproc
from seq2struct.models import variational_lstm
from seq2struct.utils import registry
from seq2struct.utils import vocab

@attr.s
class NL2CodeEncoderState:
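    """Output of NL2CodeEncoder.forward: final LSTM state, per-token memory, and the input words."""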
    state = attr.ib()
    memory = attr.ib()
    words = attr.ib()

    def find_word_occurrences(self, word):
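        """Return the indices at which `word` occurs in the encoded description."""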
        return [i for i, w in enumerate(self.words) if w == word]


@registry.register('encoder', 'NL2Code')
class NL2CodeEncoder(torch.nn.Module):
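    """Encodes a tokenized natural-language description with a bidirectional LSTM.

    Unbatched (`batched = False`): forward() encodes one description at a time.
    """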

    batched = False

    class Preproc(abstract_preproc.AbstractPreproc):
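        """Builds the encoder vocabulary from the training section and stores tokenized texts per section."""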
        def __init__(
                self,
                save_path,
                min_freq=3,
                max_count=5000):
            self.vocab_path = os.path.join(save_path, 'enc_vocab.json')
            self.data_dir = os.path.join(save_path, 'enc')

            self.vocab_builder = vocab.VocabBuilder(min_freq, max_count)
            self.init_items()
            self.vocab = None

        def init_items(self):
            # TODO: Write 'train', 'val', 'test' somewhere else
            self.texts = {'train': [], 'val': [], 'test': []}

        def validate_item(self, item, section):
            return True, None

        def add_item(self, item, section, validation_info):
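            # Only the training section contributes words to the vocabulary.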
            if section == 'train':
                for token in item.text:
                    self.vocab_builder.add_word(token)
            self.texts[section].append(item.text)

        def clear_items(self):
            self.init_items()

        def preprocess_item(self, item, validation_info):
            return item.text

        def save(self):
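            # Finalize the vocabulary and write each section's texts as a JSON-lines file.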
            os.makedirs(self.data_dir, exist_ok=True)
            self.vocab = self.vocab_builder.finish()
            self.vocab.save(self.vocab_path)

            for section, texts in self.texts.items():
                with open(os.path.join(self.data_dir, section + '.jsonl'), 'w', encoding='utf8') as f:
                    for text in texts:
                        f.write(json.dumps(text, ensure_ascii=False) + '\n')

        def load(self):
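            # Restore the vocabulary written by save(); section texts are read back via dataset().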
            self.vocab = vocab.Vocab.load(self.vocab_path)

        def dataset(self, section):
            return [
                json.loads(line)
                for line in open(os.path.join(self.data_dir, section + '.jsonl'), encoding='utf8')]

    def __init__(
            self,
            device,
            preproc,
            word_emb_size=128,
            recurrent_size=256,
            dropout=0.):
        super().__init__()
        self._device = device
        self.desc_vocab = preproc.vocab

        self.word_emb_size = word_emb_size
        self.recurrent_size = recurrent_size
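        # recurrent_size is split evenly across the two directions of the bidirectional LSTM.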
        assert self.recurrent_size % 2 == 0

        self.desc_embedding = torch.nn.Embedding(
                num_embeddings=len(self.desc_vocab),
                embedding_dim=self.word_emb_size)
        self.encoder = variational_lstm.LSTM(
                input_size=self.word_emb_size,
                hidden_size=self.recurrent_size // 2,
                bidirectional=True,
                dropout=dropout)

    def forward(self, desc_words):
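        # desc_words: list of string tokens for a single description.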
        # desc_indices shape: batch (=1) x desc length
        desc_indices = torch.tensor(
                self.desc_vocab.indices(desc_words),
                device=self._device).unsqueeze(0)
        # desc_emb shape: batch (=1) x desc length x word_emb_size
        desc_emb = self.desc_embedding(desc_indices)
        # desc_emb shape: desc length x batch (=1) x word_emb_size
        desc_emb = desc_emb.transpose(0, 1)

        # outputs shape: desc length x batch (=1) x recurrent_size
        # state shape:
        # - h: num_layers (=1) * num_directions (=2) x batch (=1) x recurrent_size / 2
        # - c: num_layers (=1) * num_directions (=2) x batch (=1) x recurrent_size / 2
        outputs, state = self.encoder(desc_emb)

        return NL2CodeEncoderState(
            state=state,
            memory=outputs.transpose(0, 1),
            words=desc_words)