# antonlabate
# ver 1.3
# d758c99
import torch
import torch.utils.data
from seq2struct.models import abstract_preproc
from seq2struct.utils import registry
class ZippedDataset(torch.utils.data.Dataset):
    """Zip several equal-length datasets; indexing yields one tuple per component."""

    def __init__(self, *components):
        assert len(components) >= 1
        sizes = [len(component) for component in components]
        # Every component must have the same length for lockstep indexing.
        assert len(set(sizes)) == 1, "Lengths don't match: {}".format(sizes)
        self.components = components

    def __getitem__(self, idx):
        # One element from each component, in the order they were passed in.
        return tuple(component[idx] for component in self.components)

    def __len__(self):
        return len(self.components[0])
@registry.register('model', 'EncDec')
class EncDecModel(torch.nn.Module):
    """Encoder-decoder model that pairs a registered encoder with a registered decoder."""

    class Preproc(abstract_preproc.AbstractPreproc):
        """Composite preprocessor delegating every step to the encoder's and decoder's Preproc."""

        def __init__(
                self,
                encoder,
                decoder,
                encoder_preproc,
                decoder_preproc):
            super().__init__()
            # Look up the Preproc classes declared by the configured encoder/decoder.
            self.enc_preproc = registry.lookup('encoder', encoder['name']).Preproc(**encoder_preproc)
            self.dec_preproc = registry.lookup('decoder', decoder['name']).Preproc(**decoder_preproc)

        def validate_item(self, item, section):
            # An item is valid only if both sub-preprocessors accept it; both infos
            # are forwarded so add_item can hand each one back to its owner.
            enc_result, enc_info = self.enc_preproc.validate_item(item, section)
            dec_result, dec_info = self.dec_preproc.validate_item(item, section)
            return enc_result and dec_result, (enc_info, dec_info)

        def add_item(self, item, section, validation_info):
            enc_info, dec_info = validation_info
            self.enc_preproc.add_item(item, section, enc_info)
            self.dec_preproc.add_item(item, section, dec_info)

        def clear_items(self):
            self.enc_preproc.clear_items()
            self.dec_preproc.clear_items()

        def save(self):
            self.enc_preproc.save()
            self.dec_preproc.save()

        def load(self):
            self.enc_preproc.load()
            self.dec_preproc.load()

        def dataset(self, section):
            # Yields (enc_input, dec_output) pairs; ZippedDataset enforces equal lengths.
            return ZippedDataset(self.enc_preproc.dataset(section), self.dec_preproc.dataset(section))

    def __init__(self, preproc, device, encoder, decoder):
        """Construct encoder and decoder from registry configs and pick a loss strategy.

        Args:
            preproc: an EncDecModel.Preproc holding enc_preproc/dec_preproc.
            device: torch device forwarded to both sub-module constructors.
            encoder, decoder: registry config dicts (must contain 'name').
        """
        super().__init__()
        self.preproc = preproc
        self.encoder = registry.construct(
            'encoder', encoder, device=device, preproc=preproc.enc_preproc)
        self.decoder = registry.construct(
            'decoder', decoder, device=device, preproc=preproc.dec_preproc)
        self.decoder.visualize_flag = False

        # Dispatch to the loss implementation matching the encoder's batching support.
        # NOTE(review): the original used getattr without a default, which raised
        # AttributeError for encoders lacking 'batched'; default to unbatched instead.
        if getattr(self.encoder, 'batched', False):
            self.compute_loss = self._compute_loss_enc_batched
        else:
            self.compute_loss = self._compute_loss_unbatched

    @staticmethod
    def _aggregate_losses(losses, debug):
        # Shared tail of all _compute_loss_* variants: per-item losses in debug
        # mode, otherwise their mean as a scalar tensor.
        if debug:
            return losses
        return torch.mean(torch.stack(losses, dim=0), dim=0)

    def _compute_loss_enc_batched(self, batch, debug=False):
        """Encode the whole batch in one encoder call, then score each item."""
        enc_inputs = [enc_input for enc_input, dec_output in batch]
        enc_states = self.encoder(enc_inputs)
        losses = [
            self.decoder.compute_loss(enc_input, dec_output, enc_state, debug)
            for enc_state, (enc_input, dec_output) in zip(enc_states, batch)]
        return self._aggregate_losses(losses, debug)

    def _compute_loss_enc_batched2(self, batch, debug=False):
        """Variant that feeds a batched encoder one item at a time (kept for parity)."""
        losses = []
        for enc_input, dec_output in batch:
            # Batched encoders take/return lists; unwrap the single state.
            enc_state, = self.encoder([enc_input])
            losses.append(self.decoder.compute_loss(enc_input, dec_output, enc_state, debug))
        return self._aggregate_losses(losses, debug)

    def _compute_loss_unbatched(self, batch, debug=False):
        """Encode and score items one at a time with an unbatched encoder."""
        losses = []
        for enc_input, dec_output in batch:
            enc_state = self.encoder(enc_input)
            losses.append(self.decoder.compute_loss(enc_input, dec_output, enc_state, debug))
        return self._aggregate_losses(losses, debug)

    def eval_on_batch(self, batch):
        """Return summed loss and item count so callers can re-average across batches."""
        mean_loss = self.compute_loss(batch).item()
        batch_size = len(batch)
        return {'loss': mean_loss * batch_size, 'total': batch_size}

    def begin_inference(self, orig_item, preproc_item):
        """Encode one preprocessed item and hand off to the decoder's inference loop.

        Args:
            orig_item: the raw (unpreprocessed) example, passed through to the decoder.
            preproc_item: (enc_input, dec_output) pair; only enc_input is used here.
        """
        # TODO: Don't hardcode the 'train' section if on-the-fly validation and
        # preprocessing of raw items is ever re-enabled here.
        enc_input, _ = preproc_item
        if self.decoder.visualize_flag:
            print('question:')
            print(enc_input['question'])
            print('columns:')
            print(enc_input['columns'])
            print('tables:')
            print(enc_input['tables'])
        # Same robust 'batched' check as in __init__ (default False = unbatched).
        if getattr(self.encoder, 'batched', False):
            enc_state, = self.encoder([enc_input])
        else:
            enc_state = self.encoder(enc_input)
        return self.decoder.begin_inference(enc_state, orig_item)