Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,118 Bytes
8cd00a9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
import webdataset as wds
import os
import torch
class ActivationsDataloader:
    """Yield shuffled, fixed-size batches of activation rows from WebDataset shards.

    For every directory in ``paths_to_datasets`` the shard
    ``<dir>/<block_name>.tar`` is opened with ``wds.WebDataset`` and decoded
    with the ``"torch"`` decoder. Each sample's ``output.pth`` or ``diff.pth``
    tensor is flattened into rows, rows from several samples are pooled in an
    in-memory buffer, and the buffer is shuffled before batches are sliced
    from it.
    """

    def __init__(self, paths_to_datasets, block_name, batch_size, output_or_diff='diff', num_in_buffer=50,
                 device='cuda'):
        """
        Args:
            paths_to_datasets: iterable of directories, each expected to hold
                ``f"{block_name}.tar"``.
            block_name: shard file stem read from every directory.
            batch_size: number of rows in every batch yielded by ``iterate``.
            output_or_diff: ``'output'`` reads each sample's ``output.pth``
                entry, ``'diff'`` reads ``diff.pth``.
            num_in_buffer: samples loaded on the first fill; later top-ups load
                ``max(1, num_in_buffer // 5)`` samples.
            device: device buffered rows are moved to. Defaults to ``'cuda'``,
                the value this class previously hard-coded.
        """
        assert output_or_diff in ['diff', 'output'], "Provide 'output' or 'diff'"
        self.dataset = wds.WebDataset(
            [os.path.join(path_to_dataset, f"{block_name}.tar")
             for path_to_dataset in paths_to_datasets]
        ).decode("torch")
        self.iter = iter(self.dataset)
        self.buffer = None  # pooled rows, (rows, features); None until first fill
        self.pointer = 0  # index of the next unconsumed row in self.buffer
        self.num_in_buffer = num_in_buffer
        self.output_or_diff = output_or_diff
        self.batch_size = batch_size
        # Row count of the most recently loaded sample; sizes the refill threshold.
        self.one_size = None
        self.device = device

    def _load_next_sample(self):
        """Read one sample and return its activations as a 2-D tensor on ``self.device``."""
        sample = next(self.iter)
        key = 'output.pth' if self.output_or_diff == 'output' else 'diff.pth'
        latents = sample[key]
        # Move axis 2 to the end and flatten every other axis, so each row is a
        # vector along the original axis 2. The permute requires a 5-D tensor;
        # what the individual axes mean is not visible here — TODO confirm.
        latents = latents.permute((0, 1, 3, 4, 2))
        latents = latents.reshape((-1, latents.shape[-1]))
        self.one_size = latents.shape[0]
        return latents.to(self.device)

    def renew_buffer(self, to_retrieve):
        """Merge the unconsumed rows with ``to_retrieve`` new samples and reshuffle.

        On success ``self.buffer`` holds the shuffled union and ``self.pointer``
        is reset to 0.

        Raises:
            StopIteration: when the underlying dataset runs out. Before
                re-raising, the unconsumed rows are put back into
                ``self.buffer`` (pointer 0, possibly ``None``) so the object
                stays in a consistent state; samples read during this call are
                discarded.
        """
        remainder = None
        if self.buffer is not None and self.buffer.shape[0] > self.pointer:
            # clone() so the remainder does not keep the old storage alive via a view.
            remainder = self.buffer[self.pointer:].clone()
        # Drop the old buffer before loading new samples to lower peak memory.
        # Rebind instead of `del` so the attribute always exists.
        self.buffer = None
        self.pointer = 0
        to_merge = [] if remainder is None else [remainder]
        try:
            for _ in range(to_retrieve):
                to_merge.append(self._load_next_sample())
        except StopIteration:
            self.buffer = remainder
            raise
        merged = torch.cat(to_merge, dim=0)
        shuffled_indices = torch.randperm(merged.shape[0])
        self.buffer = merged[shuffled_indices]

    def iterate(self):
        """Yield batches of ``batch_size`` rows until the dataset is exhausted.

        The buffer is topped up whenever fewer than
        ``num_in_buffer * one_size * 4 // 5`` unconsumed rows remain. Iteration
        stops at the first top-up that hits the end of the dataset; rows still
        in the buffer at that point are not yielded.
        """
        while True:
            if self.buffer is None or (
                self.buffer.shape[0] - self.pointer
                < self.num_in_buffer * self.one_size * 4 // 5
            ):
                # Top-ups load a fifth of the first fill, but at least one
                # sample: num_in_buffer // 5 is 0 when num_in_buffer < 5, and
                # then the buffer could never grow again.
                if self.buffer is None:
                    to_retrieve = self.num_in_buffer
                else:
                    to_retrieve = max(1, self.num_in_buffer // 5)
                try:
                    self.renew_buffer(to_retrieve)
                except StopIteration:
                    break
            batch = self.buffer[self.pointer: self.pointer + self.batch_size]
            self.pointer += self.batch_size
            # Guards against the buffer running short (e.g. batch_size larger
            # than the refill threshold).
            assert batch.shape[0] == self.batch_size
            yield batch
|