#!/usr/bin/env python3
# coding=utf-8

import torch
import torch.nn.functional as F
class Batch:
    @staticmethod
    def build(data):
        # Collate a list of example dicts into a single batch dict,
        # stacking (and padding) the tensors of each field.
        fields = list(data[0].keys())
        transposed = {}
        for field in fields:
            if isinstance(data[0][field], tuple):
                # Tuple-valued fields are stacked component-wise.
                transposed[field] = tuple(
                    Batch._stack(field, [example[field][i] for example in data])
                    for i in range(len(data[0][field]))
                )
            else:
                transposed[field] = Batch._stack(field, [example[field] for example in data])
        return transposed

    @staticmethod
    def _stack(field: str, examples):
        # Anchored labels have a ragged structure, so they are kept as a plain list.
        if field == "anchored_labels":
            return examples

        dim = examples[0].dim()
        if dim == 0:
            # Scalars can be stacked directly.
            return torch.stack(examples)

        # Pad every example up to the per-dimension maximum before stacking.
        lengths = [max(example.size(i) for example in examples) for i in range(dim)]
        if any(length == 0 for length in lengths):
            # Every example is empty along some dimension; return an
            # (uninitialized) LongTensor placeholder of the right shape.
            return torch.LongTensor(len(examples), *lengths)

        examples = [F.pad(example, Batch._pad_size(example, lengths)) for example in examples]
        return torch.stack(examples)
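
    # Illustration of the intended padding behavior (shapes are hypothetical,
    # chosen only for the example):
    #   _stack("tokens", [torch.tensor([1, 2]), torch.tensor([3, 4, 5])])
    # returns a LongTensor of size 2x3, with the first row right-padded
    # by one zero.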

    @staticmethod
    def _pad_size(example, total_size):
        # F.pad expects (left, right) pairs starting from the last dimension,
        # so iterate over the target sizes in reverse and pad on the right only.
        return [p for i, l in enumerate(total_size[::-1]) for p in (0, l - example.size(-1 - i))]
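
    # Worked example (hypothetical shapes): an example of size (2, 3) padded
    # to total_size [4, 5] gives _pad_size(...) == [0, 2, 0, 2], i.e. pad the
    # last dimension by 2 on the right and the second-to-last by 2 on the right.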

    @staticmethod
    def index_select(batch, indices):
        # Keep only the examples at the given indices, preserving field structure.
        filtered_batch = {}
        for key, examples in batch.items():
            if isinstance(examples, (list, tuple)):
                filtered_batch[key] = [example.index_select(0, indices) for example in examples]
            else:
                filtered_batch[key] = examples.index_select(0, indices)
        return filtered_batch

    @staticmethod
    def to_str(batch):
        # One line per field with a compact tensor description.
        return "\n".join(f"\t{name}: {Batch._short_str(item)}" for name, item in batch.items())

    @staticmethod
    def to(batch, device):
        # Move every field of the batch to the given device.
        return {field: Batch._to(value, device) for field, value in batch.items()}

    @staticmethod
    def _short_str(tensor):
        # unwrap a Variable to its underlying tensor
        if not torch.is_tensor(tensor):
            # (1) unpack a Variable
            if hasattr(tensor, "data"):
                tensor = tensor.data
            # (2) handle include_lengths-style tuples and lists
            elif isinstance(tensor, (tuple, list)):
                return str(tuple(Batch._short_str(t) for t in tensor))
            # (3) fall back to the default str
            else:
                return str(tensor)

        # adapted from torch._tensor_str
        size_str = "x".join(str(size) for size in tensor.size())
        device_str = "" if not tensor.is_cuda else " (GPU {})".format(tensor.get_device())
        return "[{} of size {}{}]".format(torch.typename(tensor), size_str, device_str)

    @staticmethod
    def _to(tensor, device):
        if not torch.is_tensor(tensor):
            # Recurse into containers; anything else cannot be moved to a device.
            if isinstance(tensor, tuple):
                return tuple(Batch._to(t, device) for t in tensor)
            elif isinstance(tensor, list):
                return [Batch._to(t, device) for t in tensor]
            else:
                raise TypeError(f"unsupported type of {tensor} to be cast to a device")
        return tensor.to(device, non_blocking=True)
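

# A minimal usage sketch (not part of the original module): the field names
# and shapes below are hypothetical, chosen only to exercise build,
# index_select, to, and to_str on variable-length examples.
if __name__ == "__main__":
    examples = [
        {"tokens": torch.tensor([1, 2, 3]), "label": torch.tensor(0)},
        {"tokens": torch.tensor([4, 5]), "label": torch.tensor(1)},
    ]
    batch = Batch.build(examples)  # "tokens" is right-padded and stacked to 2x3
    print(Batch.to_str(batch))

    subset = Batch.index_select(batch, torch.tensor([1]))  # keep the second example
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch = Batch.to(batch, device)  # move every field to the device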