Spaces:
Build error
Build error
from torch.nn.parallel import DataParallel | |
import torch | |
from torch.nn.parallel._functions import Scatter | |
from torch.nn.parallel.parallel_apply import parallel_apply | |
def scatter(inputs, target_gpus, chunk_sizes, dim=0):
    r"""Slice tensors into chunks of explicit sizes and distribute them across GPUs.

    Every tensor found in ``inputs`` is split along ``dim`` by
    ``Scatter.apply`` into pieces whose sizes are given by ``chunk_sizes``.
    Per ``torch.cuda.comm.scatter``, ``chunk_sizes`` should have one entry per
    target device and sum to the tensor's size along ``dim``.
    Tuples (including namedtuples), lists and dicts are traversed
    recursively. Any other object, and any empty container, is duplicated by
    reference once per entry of ``target_gpus``.

    Args:
        inputs: a tensor, or an arbitrarily nested tuple/list/dict of tensors
            and other objects.
        target_gpus: devices to scatter to.
        chunk_sizes: size of the slice along ``dim`` for each target device.
        dim: dimension along which tensors are split.

    Returns:
        One entry per target device:
        - a tuple input gives a list of tuples (namedtuples keep their type);
        - a list input gives a list of lists;
        - a dict input gives a list of dicts of the same mapping type;
        - a bare tensor gives the chunks returned by ``Scatter.apply``;
        - anything else gives ``[inputs, inputs, ...]``.

    Raises:
        RuntimeError: if ``Scatter.apply`` fails for a tensor, for example
            because ``chunk_sizes`` does not match the tensor's size along
            ``dim``. The original exception is chained as ``__cause__``.
    """
    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            try:
                return Scatter.apply(target_gpus, chunk_sizes, dim, obj)
            except Exception as err:
                # Surface the values needed to diagnose a bad split to the
                # caller instead of printing them and terminating the process.
                raise RuntimeError(
                    f"Failed to scatter tensor of size {tuple(obj.size())} "
                    f"along dim={dim!r} with chunk_sizes={chunk_sizes!r} "
                    f"to target_gpus={target_gpus!r}"
                ) from err
        if isinstance(obj, tuple) and len(obj) > 0:
            per_device = zip(*map(scatter_map, obj))
            if hasattr(obj, "_fields"):
                # namedtuple: rebuild each per-device slice with the same type.
                return [type(obj)(*fields) for fields in per_device]
            return list(per_device)
        if isinstance(obj, list) and len(obj) > 0:
            return list(map(list, zip(*map(scatter_map, obj))))
        if isinstance(obj, dict) and len(obj) > 0:
            # Each item becomes a per-device (key, chunk) pair; regroup the
            # pairs by device and rebuild the original mapping type.
            return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
        # Non-tensor leaves (and empty containers) are shared by reference.
        return [obj for _ in target_gpus]

    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell.
    try:
        return scatter_map(inputs)
    finally:
        scatter_map = None
def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0):
    r"""Scatter positional and keyword arguments with the same explicit chunk sizes.

    ``inputs`` and ``kwargs`` are each passed through :func:`scatter` with
    the same ``target_gpus``, ``chunk_sizes`` and ``dim``. A falsy container
    is not scattered and counts as an empty per-device list. The shorter of
    the two per-device lists is then padded, with ``()`` for positional
    arguments or ``{}`` for keyword arguments, so both have the same length.

    Returns:
        A pair ``(inputs, kwargs)`` of equal-length tuples, one entry per device.
    """
    def split_one(container):
        return scatter(container, target_gpus, chunk_sizes, dim) if container else []

    per_device_args = split_one(inputs)
    per_device_kwargs = split_one(kwargs)

    gap = len(per_device_kwargs) - len(per_device_args)
    if gap > 0:
        per_device_args.extend(() for _ in range(gap))
    elif gap < 0:
        per_device_kwargs.extend({} for _ in range(-gap))

    return tuple(per_device_args), tuple(per_device_kwargs)
class BalancedDataParallel(DataParallel):
    """``DataParallel`` variant that puts a configurable share of the batch on device 0.

    ``nn.DataParallel`` splits each batch evenly across ``device_ids``. This
    wrapper instead sends exactly ``gpu0_bsz`` samples (along ``self.dim``)
    to ``device_ids[0]``. The rest of the batch is divided evenly over the
    remaining devices. Outputs are gathered on ``self.output_device``.

    Args:
        gpu0_bsz: number of samples to place on ``device_ids[0]``. ``0``
            means device 0 receives no slice of the batch whenever another
            device is available.
        *args, **kwargs: forwarded unchanged to ``nn.DataParallel``.

    Raises:
        ValueError: if ``gpu0_bsz`` is negative.
    """

    def __init__(self, gpu0_bsz, *args, **kwargs):
        if gpu0_bsz < 0:
            # A negative first chunk can never form a valid split.
            raise ValueError(f"gpu0_bsz must be >= 0, got {gpu0_bsz}")
        super().__init__(*args, **kwargs)
        self.gpu0_bsz = gpu0_bsz

    def forward(self, *inputs, **kwargs):
        """Run the module on an unevenly split batch and gather the outputs.

        If ``device_ids`` is empty, the module is called directly. If there
        is only one device, the module runs once on that device.
        """
        if not self.device_ids:
            return self.module(*inputs, **kwargs)
        # Device 0 can only be skipped when another device can take the work.
        offset = 1 if self.gpu0_bsz == 0 and len(self.device_ids) > 1 else 0
        device_ids = self.device_ids[offset:]
        inputs, kwargs = self.scatter(inputs, kwargs, device_ids)
        if not inputs and not kwargs:
            # A call with no arguments still runs the module once, on the
            # first target device.
            inputs, kwargs = ((),), ({},)
        if len(self.device_ids) == 1:
            return self.module(*inputs[0], **kwargs[0])
        # Replication broadcasts from the module's device, so the target
        # list must start at self.device_ids[0]. Cover exactly the devices
        # that received a chunk (plus device 0 when it is skipped), then drop
        # device 0's copy. This keeps len(replicas) == len(inputs), and each
        # replica sits on the same device as its chunk.
        replicas = self.replicate(self.module, self.device_ids[:offset + len(inputs)])
        replicas = replicas[offset:]
        outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs)
        return self.gather(outputs, self.output_device)

    def parallel_apply(self, replicas, device_ids, inputs, kwargs):
        # The target devices are passed in explicitly because they may skip
        # device 0. Only the first len(inputs) devices received a chunk.
        return parallel_apply(replicas, inputs, kwargs, device_ids[:len(inputs)])

    def scatter(self, inputs, kwargs, device_ids):
        """Split ``inputs``/``kwargs`` so that device 0 receives ``gpu0_bsz`` samples.

        The batch size is read from ``inputs[0]`` along ``self.dim``. The
        remaining ``bsz - gpu0_bsz`` samples are divided evenly over the
        other ``len(self.device_ids) - 1`` devices. Any remainder is handed
        out one sample at a time, starting at device 1. When ``gpu0_bsz`` is
        0, the empty first chunk is dropped, so ``device_ids`` is expected to
        be ``self.device_ids[1:]``.

        Falls back to ``DataParallel.scatter`` (even split) in these cases:
        - fewer than two devices are configured;
        - there are no positional inputs;
        - ``inputs[0]`` is not a tensor;
        - ``gpu0_bsz`` would not be smaller than the other devices' share.
        """
        num_dev = len(self.device_ids)
        if num_dev < 2 or not inputs or not isinstance(inputs[0], torch.Tensor):
            return super().scatter(inputs, kwargs, device_ids)
        gpu0_bsz = self.gpu0_bsz
        bsz = inputs[0].size(self.dim)
        bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)
        if gpu0_bsz >= bsz_unit:
            return super().scatter(inputs, kwargs, device_ids)
        chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)
        # The remainder is < num_dev - 1, so index i + 1 stays in range.
        for i in range(bsz - sum(chunk_sizes)):
            chunk_sizes[i + 1] += 1
        if gpu0_bsz == 0:
            chunk_sizes = chunk_sizes[1:]
        return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim)