ssa-perin / data /batch.py
larkkin's picture
Add supporting code from perin
7daaa6b
#!/usr/bin/env python3
# coding=utf-8
import torch
import torch.nn.functional as F
class Batch:
    """Static helpers for collating example dicts into a batch dict.

    A batch is a plain ``dict`` keyed by field name. As produced by
    :meth:`build`, each value is one of:

    * a single tensor with the examples stacked along a new leading dimension,
    * a ``tuple`` of such tensors (when the per-example value is a tuple), or
    * the untouched ``list`` of per-example values for ``"anchored_labels"``.
    """

    @staticmethod
    def build(data):
        """Collate a sequence of example dicts into one batch dict.

        Args:
            data: non-empty sequence of dicts; the keys of ``data[0]`` define
                the fields, and every example is indexed with those keys.

        Returns:
            A dict with one entry per field, see the class docstring for the
            possible value kinds.
        """
        batch = {}
        for field, first_value in data[0].items():
            if isinstance(first_value, tuple):
                # Tuple-valued field: stack each component independently.
                batch[field] = tuple(
                    Batch._stack(field, [example[field][i] for example in data])
                    for i in range(len(first_value))
                )
            else:
                batch[field] = Batch._stack(field, [example[field] for example in data])
        return batch

    @staticmethod
    def _stack(field: str, examples):
        """Stack per-example tensors, zero-padding them to a common shape.

        Args:
            field: name of the field; ``"anchored_labels"`` is returned as-is.
            examples: list of tensors that all have the same number of
                dimensions (the rank is read from ``examples[0]``).

        Returns:
            ``examples`` unchanged for ``"anchored_labels"``; otherwise a
            tensor of shape ``(len(examples), *max_size_per_dim)``.
        """
        if field == "anchored_labels":
            return examples

        dim = examples[0].dim()
        if dim == 0:
            # Scalars need no padding.
            return torch.stack(examples)

        # Largest extent along every dimension over the whole batch.
        lengths = [max(example.size(i) for example in examples) for i in range(dim)]

        if any(length == 0 for length in lengths):
            # Some dimension is empty in every example, so the result holds no
            # elements. Build it from the first example so that dtype and
            # device match what the padded path below would have produced.
            return examples[0].new_empty((len(examples), *lengths))

        padded = [F.pad(example, Batch._pad_size(example, lengths)) for example in examples]
        return torch.stack(padded)

    @staticmethod
    def _pad_size(example, total_size):
        """Return the ``F.pad`` padding spec that grows ``example`` to ``total_size``.

        ``F.pad`` expects ``(left, right)`` pairs starting from the LAST
        dimension, hence the reversed traversal. Only the right side is padded.
        """
        padding = []
        for i, target in enumerate(total_size[::-1]):
            padding.extend((0, target - example.size(-1 - i)))
        return padding

    @staticmethod
    def index_select(batch, indices):
        """Select entries along dimension 0 of every field of ``batch``.

        Args:
            batch: dict of tensors, or tuples/lists of tensors.
            indices: index tensor passed to ``Tensor.index_select``.

        Returns:
            A new dict with the same keys. Tuple values stay tuples and list
            values stay lists, matching what ``build`` and ``to`` produce.
        """
        filtered_batch = {}
        for key, examples in batch.items():
            if isinstance(examples, tuple):
                filtered_batch[key] = tuple(example.index_select(0, indices) for example in examples)
            elif isinstance(examples, list):
                # NOTE(review): for a list field this indexes dim 0 of every
                # element instead of picking list entries. For the unstacked
                # "anchored_labels" list returned by `_stack` that may not be
                # the intent -- confirm against the callers.
                filtered_batch[key] = [example.index_select(0, indices) for example in examples]
            else:
                filtered_batch[key] = examples.index_select(0, indices)
        return filtered_batch

    @staticmethod
    def to_str(batch):
        """Return a multi-line summary with one ``\\t<name>: <summary>`` line per field."""
        return "\n".join(f"\t{name}: {Batch._short_str(item)}" for name, item in batch.items())

    @staticmethod
    def to(batch, device):
        """Return a copy of ``batch`` with every tensor moved to ``device``."""
        return {field: Batch._to(value, device) for field, value in batch.items()}

    @staticmethod
    def _short_str(tensor):
        """Return a short type/size description of ``tensor``.

        Tuples and lists are summarised element-wise; objects that are neither
        tensors, containers, nor wrappers exposing ``.data`` fall back to
        ``str()``.
        """
        # unwrap variable to tensor
        if not torch.is_tensor(tensor):
            # (1) unpack variable
            if hasattr(tensor, "data"):
                tensor = tensor.data
            # (2) handle include_lengths
            elif isinstance(tensor, (tuple, list)):
                return str(tuple(Batch._short_str(t) for t in tensor))
            # (3) fallback to default str
            else:
                return str(tensor)

        # copied from torch _tensor_str
        size_str = "x".join(str(size) for size in tensor.size())
        device_str = f" (GPU {tensor.get_device()})" if tensor.is_cuda else ""
        return f"[{torch.typename(tensor)} of size {size_str}{device_str}]"

    @staticmethod
    def _to(tensor, device):
        """Move a tensor, or a nested tuple/list of tensors, to ``device``.

        The container type (tuple vs. list) is preserved. The transfer is
        requested with ``non_blocking=True``.

        Raises:
            TypeError: if a value is neither a tensor, a tuple, nor a list.
        """
        if torch.is_tensor(tensor):
            return tensor.to(device, non_blocking=True)
        if isinstance(tensor, tuple):
            return tuple(Batch._to(t, device) for t in tensor)
        if isinstance(tensor, list):
            return [Batch._to(t, device) for t in tensor]
        raise TypeError(f"unsupported type of {tensor} to be casted to cuda")