import json
import os

import attr
import torch

from seq2struct.models import abstract_preproc
from seq2struct.models import variational_lstm
from seq2struct.utils import registry
from seq2struct.utils import vocab


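# Container for the encoder's output: the LSTM's final state, the per-token
# outputs ("memory"), and the original input words, so that downstream code
# can look up where a given word occurred.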
@attr.s
class NL2CodeEncoderState:
    state = attr.ib()
    memory = attr.ib()
    words = attr.ib()

    def find_word_occurrences(self, word):
        return [i for i, w in enumerate(self.words) if w == word]


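# Encodes a tokenized natural-language description with an embedding layer
# followed by a bidirectional variational LSTM. Registered under
# ('encoder', 'NL2Code') so it can be looked up by name through the registry.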
@registry.register('encoder', 'NL2Code')
class NL2CodeEncoder(torch.nn.Module):

    # This encoder handles one example at a time; forward() adds a batch
    # dimension of 1.
    batched = False

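    # Preprocessor for encoder inputs: builds the token vocabulary from the
    # training section (via vocab.VocabBuilder with min_freq / max_count
    # cutoffs) and stores each section's tokenized texts as a .jsonl file
    # under <save_path>/enc.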
    class Preproc(abstract_preproc.AbstractPreproc):
        def __init__(
                self,
                save_path,
                min_freq=3,
                max_count=5000):
            self.vocab_path = os.path.join(save_path, 'enc_vocab.json')
            self.data_dir = os.path.join(save_path, 'enc')

            self.vocab_builder = vocab.VocabBuilder(min_freq, max_count)
            self.init_items()
            self.vocab = None

        def init_items(self):
            self.texts = {'train': [], 'val': [], 'test': []}

        def validate_item(self, item, section):
            return True, None

        def add_item(self, item, section, validation_info):
            # Only training items contribute to the vocabulary.
            if section == 'train':
                for token in item.text:
                    self.vocab_builder.add_word(token)
            self.texts[section].append(item.text)

        def clear_items(self):
            self.init_items()

        def preprocess_item(self, item, validation_info):
            return item.text

        def save(self):
            os.makedirs(self.data_dir, exist_ok=True)
            self.vocab = self.vocab_builder.finish()
            self.vocab.save(self.vocab_path)

            for section, texts in self.texts.items():
                with open(os.path.join(self.data_dir, section + '.jsonl'), 'w', encoding='utf8') as f:
                    for text in texts:
                        f.write(json.dumps(text, ensure_ascii=False) + '\n')

        def load(self):
            self.vocab = vocab.Vocab.load(self.vocab_path)

        def dataset(self, section):
            return [
                json.loads(line)
                for line in open(os.path.join(self.data_dir, section + '.jsonl'), encoding='utf8')]

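    # The network itself: a word embedding followed by a bidirectional LSTM.
    # Each direction gets recurrent_size // 2 hidden units, so the
    # concatenated per-token output has recurrent_size dimensions.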
    def __init__(
            self,
            device,
            preproc,
            word_emb_size=128,
            recurrent_size=256,
            dropout=0.):
        super().__init__()
        self._device = device
        self.desc_vocab = preproc.vocab

        self.word_emb_size = word_emb_size
        self.recurrent_size = recurrent_size
        # recurrent_size is split evenly between the two LSTM directions.
        assert self.recurrent_size % 2 == 0

        self.desc_embedding = torch.nn.Embedding(
            num_embeddings=len(self.desc_vocab),
            embedding_dim=self.word_emb_size)
        self.encoder = variational_lstm.LSTM(
            input_size=self.word_emb_size,
            hidden_size=self.recurrent_size // 2,
            bidirectional=True,
            dropout=dropout)

    def forward(self, desc_words):
        # desc_indices shape: batch (=1) x length
        desc_indices = torch.tensor(
            self.desc_vocab.indices(desc_words),
            device=self._device).unsqueeze(0)
        # desc_emb shape: batch (=1) x length x word_emb_size
        desc_emb = self.desc_embedding(desc_indices)
        # Reorder to length x batch (=1) x word_emb_size for the
        # sequence-first LSTM.
        desc_emb = desc_emb.transpose(0, 1)

        # outputs shape: length x batch (=1) x recurrent_size
        outputs, state = self.encoder(desc_emb)

        # memory is transposed back to batch (=1) x length x recurrent_size.
        return NL2CodeEncoderState(
            state=state,
            memory=outputs.transpose(0, 1),
            words=desc_words)
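# Example usage (a minimal sketch; assumes preprocessed data has already been
# written with Preproc.save() under 'save_dir'):
#
#     preproc = NL2CodeEncoder.Preproc('save_dir')
#     preproc.load()
#     encoder = NL2CodeEncoder(device=torch.device('cpu'), preproc=preproc)
#     enc_state = encoder(['how', 'many', 'singers', 'are', 'there'])
#     enc_state.memory  # shape: batch (=1) x length x recurrent_size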