File size: 5,128 Bytes
d758c99 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import torch
import torch.utils.data
from seq2struct.models import abstract_preproc
from seq2struct.utils import registry
class ZippedDataset(torch.utils.data.Dataset):
    """Dataset that pairs up several equal-length datasets index-by-index.

    Indexing at `idx` returns a tuple with the `idx`-th item of each
    component dataset, in the order the components were given.
    """

    def __init__(self, *components):
        # At least one component is required, and all must agree on length.
        assert len(components) >= 1
        sizes = [len(component) for component in components]
        assert all(
            size == sizes[0] for size in sizes[1:]), "Lengths don't match: {}".format(sizes)
        self.components = components

    def __getitem__(self, idx):
        # One item per component, zipped into a tuple.
        return tuple(component[idx] for component in self.components)

    def __len__(self):
        # All components have the same length (checked in __init__).
        return len(self.components[0])
@registry.register('model', 'EncDec')
class EncDecModel(torch.nn.Module):
    """Encoder-decoder model whose encoder and decoder are constructed
    from the registry according to the given configs.

    The loss implementation is selected at construction time depending on
    whether the encoder supports batched inputs.
    """

    class Preproc(abstract_preproc.AbstractPreproc):
        """Preprocessor that delegates every operation to the configured
        encoder's and decoder's own Preproc objects."""

        def __init__(
                self,
                encoder,
                decoder,
                encoder_preproc,
                decoder_preproc):
            super().__init__()
            # Look up the Preproc class registered under each component's
            # name and instantiate it with the component-specific config.
            self.enc_preproc = registry.lookup('encoder', encoder['name']).Preproc(**encoder_preproc)
            self.dec_preproc = registry.lookup('decoder', decoder['name']).Preproc(**decoder_preproc)

        def validate_item(self, item, section):
            # An item is valid only if both sides accept it; both sides'
            # validation info is forwarded so add_item can split it back.
            enc_result, enc_info = self.enc_preproc.validate_item(item, section)
            dec_result, dec_info = self.dec_preproc.validate_item(item, section)
            return enc_result and dec_result, (enc_info, dec_info)

        def add_item(self, item, section, validation_info):
            enc_info, dec_info = validation_info
            self.enc_preproc.add_item(item, section, enc_info)
            self.dec_preproc.add_item(item, section, dec_info)

        def clear_items(self):
            self.enc_preproc.clear_items()
            self.dec_preproc.clear_items()

        def save(self):
            self.enc_preproc.save()
            self.dec_preproc.save()

        def load(self):
            self.enc_preproc.load()
            self.dec_preproc.load()

        def dataset(self, section):
            # Pair encoder and decoder examples index-by-index; ZippedDataset
            # asserts the two datasets have equal length.
            return ZippedDataset(self.enc_preproc.dataset(section), self.dec_preproc.dataset(section))

    def __init__(self, preproc, device, encoder, decoder):
        super().__init__()
        self.preproc = preproc
        self.encoder = registry.construct(
            'encoder', encoder, device=device, preproc=preproc.enc_preproc)
        self.decoder = registry.construct(
            'decoder', decoder, device=device, preproc=preproc.dec_preproc)
        self.decoder.visualize_flag = False

        # Select the loss implementation once, up front. Default to the
        # unbatched path when the encoder does not declare a `batched`
        # attribute: bare getattr(obj, name) with no default raises
        # AttributeError, which would make encoders without the attribute
        # fail here instead of falling through to the else branch.
        if getattr(self.encoder, 'batched', False):
            self.compute_loss = self._compute_loss_enc_batched
        else:
            self.compute_loss = self._compute_loss_unbatched

    def _compute_loss_enc_batched(self, batch, debug=False):
        """Encode the whole batch in one encoder call, then decode per item.

        Returns the list of per-item losses when `debug` is True, otherwise
        the mean loss over the batch as a scalar tensor.
        """
        losses = []
        d = [enc_input for enc_input, dec_output in batch]
        enc_states = self.encoder(d)
        for enc_state, (enc_input, dec_output) in zip(enc_states, batch):
            loss = self.decoder.compute_loss(enc_input, dec_output, enc_state, debug)
            losses.append(loss)
        if debug:
            return losses
        else:
            return torch.mean(torch.stack(losses, dim=0), dim=0)

    def _compute_loss_enc_batched2(self, batch, debug=False):
        """Alternate batched-encoder loss: encodes items one at a time
        through the batched encoder interface. Not selected by __init__;
        kept for comparison/debugging."""
        losses = []
        for enc_input, dec_output in batch:
            # Batched encoder returns a sequence; unpack its single element.
            enc_state, = self.encoder([enc_input])
            loss = self.decoder.compute_loss(enc_input, dec_output, enc_state, debug)
            losses.append(loss)
        if debug:
            return losses
        else:
            return torch.mean(torch.stack(losses, dim=0), dim=0)

    def _compute_loss_unbatched(self, batch, debug=False):
        """Loss for encoders that take one example per call.

        Returns the list of per-item losses when `debug` is True, otherwise
        the mean loss over the batch as a scalar tensor.
        """
        losses = []
        for enc_input, dec_output in batch:
            enc_state = self.encoder(enc_input)
            loss = self.decoder.compute_loss(enc_input, dec_output, enc_state, debug)
            losses.append(loss)
        if debug:
            return losses
        else:
            return torch.mean(torch.stack(losses, dim=0), dim=0)

    def eval_on_batch(self, batch):
        """Return {'loss': summed loss over the batch, 'total': batch size}.

        compute_loss returns a mean, so it is scaled back up by the batch
        size to make losses summable across batches.
        """
        mean_loss = self.compute_loss(batch).item()
        batch_size = len(batch)
        result = {'loss': mean_loss * batch_size, 'total': batch_size}
        return result

    def begin_inference(self, orig_item, preproc_item):
        """Encode one preprocessed item and hand off to the decoder's
        inference procedure; returns whatever the decoder returns."""
        ## TODO: Don't hardcode train
        #valid, validation_info = self.preproc.enc_preproc.validate_item(item, 'train')
        #if not valid:
        #    return None
        #enc_input = self.preproc.enc_preproc.preprocess_item(item, validation_info)
        enc_input, _ = preproc_item
        if self.decoder.visualize_flag:
            # NOTE(review): assumes enc_input is a dict with 'question',
            # 'columns', 'tables' keys — true for the text-to-SQL encoders,
            # confirm for any new encoder before enabling visualization.
            print('question:')
            print(enc_input['question'])
            print('columns:')
            print(enc_input['columns'])
            print('tables:')
            print(enc_input['tables'])
        # Same batched/unbatched dispatch as in __init__, with the same
        # safe default of False when `batched` is absent.
        if getattr(self.encoder, 'batched', False):
            enc_state, = self.encoder([enc_input])
        else:
            enc_state = self.encoder(enc_input)
        return self.decoder.begin_inference(enc_state, orig_item)
|