import torch
import random
import time


class Bucketeer:
    """Groups dataloader samples into resolution buckets so that every
    emitted batch contains videos of a single spatial size.

    Expects the wrapped dataloader to yield lists of sample dicts, each
    holding a 'video' tensor whose last two dimensions are (H, W); buckets
    are keyed by (width, height), i.e. (shape[-1], shape[-2]).
    """

    def __init__(
        self, dataloader,
        sizes=[(256, 256), (192, 384), (192, 320), (384, 192), (320, 192)],
        is_infinite=True, epoch=0,
    ):
        self.sizes = sizes
        self.batch_size = dataloader.batch_size
        self._dataloader = dataloader
        self.iterator = iter(dataloader)
        self.sampler = dataloader.sampler
        # One pending-sample list per (width, height) bucket.
        self.buckets = {s: [] for s in self.sizes}
        self.is_infinite = is_infinite
        self._epoch = epoch

    def get_available_batch(self):
        # Collect every bucket that has accumulated a full batch.
        available_sizes = []
        for b in self.buckets:
            if len(self.buckets[b]) >= self.batch_size:
                available_sizes.append(b)

        if len(available_sizes) == 0:
            return None
        # Pick one full bucket at random and pop batch_size samples from it.
        b = random.choice(available_sizes)
        batch = self.buckets[b][:self.batch_size]
        self.buckets[b] = self.buckets[b][self.batch_size:]
        return batch

    def __next__(self):
        batch = self.get_available_batch()
        while batch is None:
            try:
                elements = next(self.iterator)
            except StopIteration:
                if self.is_infinite:
                    # Underlying loader is exhausted: advance the epoch (so
                    # distributed samplers reshuffle) and restart it.
                    self._epoch += 1
                    if hasattr(self._dataloader.sampler, "set_epoch"):
                        self._dataloader.sampler.set_epoch(self._epoch)
                    time.sleep(2)  # give worker processes time to shut down
                    self.iterator = iter(self._dataloader)
                    elements = next(self.iterator)
                else:
                    raise StopIteration

            for dct in elements:
                try:
                    img = dct['video']
                    # Bucket key is (W, H), matching the entries in self.sizes.
                    size = (img.shape[-1], img.shape[-2])
                    self.buckets[size].append(
                        {'video': img, **{k: dct[k] for k in dct if k != 'video'}}
                    )
                except Exception:
                    # Samples whose size has no bucket (KeyError above) or
                    # that are otherwise malformed are silently dropped.
                    continue

            batch = self.get_available_batch()

        # Collate: stack tensors along a new batch dimension; other fields
        # (e.g. captions) stay as plain lists.
        out = {k: [batch[i][k] for i in range(len(batch))] for k in batch[0]}
        return {k: torch.stack(o, dim=0) if isinstance(o[0], torch.Tensor) else o
                for k, o in out.items()}

    def __iter__(self):
        return self

    def __len__(self):
        # Length of the underlying loader, which only approximates the
        # number of bucketed batches this wrapper will actually yield.
        return len(self.iterator)

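
# A minimal usage sketch (illustrative only, not part of the original module).
# It assumes a hypothetical dataset whose samples are dicts with a 'video'
# tensor of shape (C, H, W), and a collate_fn that returns the raw list of
# dicts so Bucketeer can route each sample into its (W, H) bucket.
def _demo_bucketeer():
    from torch.utils.data import DataLoader, Dataset

    class _ToyAspectDataset(Dataset):
        # Hypothetical dataset that alternates between two bucket sizes.
        def __len__(self):
            return 64

        def __getitem__(self, idx):
            # (H=256, W=256) lands in bucket (256, 256); (H=384, W=192) lands
            # in bucket (192, 384) under the (shape[-1], shape[-2]) keying.
            h, w = (256, 256) if idx % 2 == 0 else (384, 192)
            return {'video': torch.randn(3, h, w), 'caption': f'sample {idx}'}

    loader = DataLoader(_ToyAspectDataset(), batch_size=4,
                        collate_fn=lambda batch: batch)
    bucketed = Bucketeer(loader, is_infinite=False)
    batch = next(bucketed)
    # All videos in the emitted batch share one resolution after stacking.
    print(batch['video'].shape, batch['caption'][:2])
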
class TemporalLengthBucketeer:
    """Groups dataloader samples into buckets by frame count so that every
    emitted batch contains video latents of a single temporal length.

    Expects the wrapped dataloader to yield lists of sample dicts, each
    holding a 'video' latent whose third dimension (shape[2]) is the number
    of frames. Iteration is always infinite: the underlying loader is
    restarted whenever it is exhausted.
    """

    def __init__(
        self, dataloader, max_frames=16, epoch=0,
    ):
        self.batch_size = dataloader.batch_size
        self._dataloader = dataloader
        self.iterator = iter(dataloader)
        # One pending-sample list per temporal length from 1 to max_frames.
        self.buckets = {temp: [] for temp in range(1, max_frames + 1)}
        self._epoch = epoch

    def get_available_batch(self):
        # Mirrors Bucketeer.get_available_batch, keyed by frame count.
        available_lengths = []
        for b in self.buckets:
            if len(self.buckets[b]) >= self.batch_size:
                available_lengths.append(b)

        if len(available_lengths) == 0:
            return None
        # Pick one full bucket at random and pop batch_size samples from it.
        b = random.choice(available_lengths)
        batch = self.buckets[b][:self.batch_size]
        self.buckets[b] = self.buckets[b][self.batch_size:]
        return batch

    def __next__(self):
        batch = self.get_available_batch()
        while batch is None:
            try:
                elements = next(self.iterator)
            except StopIteration:
                # Underlying loader is exhausted: advance the epoch (so
                # distributed samplers reshuffle) and restart it.
                self._epoch += 1
                if hasattr(self._dataloader.sampler, "set_epoch"):
                    self._dataloader.sampler.set_epoch(self._epoch)
                time.sleep(2)  # give worker processes time to shut down
                self.iterator = iter(self._dataloader)
                elements = next(self.iterator)

            for dct in elements:
                try:
                    video_latent = dct['video']
                    # Bucket key is the temporal length of the latent.
                    temp = video_latent.shape[2]
                    self.buckets[temp].append(
                        {'video': video_latent,
                         **{k: dct[k] for k in dct if k != 'video'}}
                    )
                except Exception:
                    # Samples with no matching bucket or malformed entries
                    # are silently dropped.
                    continue

            batch = self.get_available_batch()

        # Collate: latents already carry a leading batch dimension, so
        # concatenate along dim 0; non-tensor fields stay as lists.
        out = {k: [batch[i][k] for i in range(len(batch))] for k in batch[0]}
        out = {k: torch.cat(o, dim=0) if isinstance(o[0], torch.Tensor) else o
               for k, o in out.items()}

        if 'prompt_embed' in out:
            # Repack precomputed text-encoder outputs under a single 'text'
            # key, using the plural key names expected downstream.
            prompt_embeds = out.pop('prompt_embed').clone()
            prompt_attention_mask = out.pop('prompt_attention_mask').clone()
            pooled_prompt_embeds = out.pop('pooled_prompt_embed').clone()

            out['text'] = {
                'prompt_embeds': prompt_embeds,
                'prompt_attention_mask': prompt_attention_mask,
                'pooled_prompt_embeds': pooled_prompt_embeds,
            }

        return out

    def __iter__(self):
        return self

    def __len__(self):
        # Length of the underlying loader, which only approximates the
        # number of bucketed batches this wrapper will actually yield.
        return len(self.iterator)
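

# A minimal usage sketch (illustrative only, not part of the original module).
# It assumes pre-encoded video latents of shape (1, C, T, H, W), so that
# shape[2] is the frame count used as the bucket key and torch.cat(dim=0)
# in __next__ assembles the final batch. The dataset and shapes below are
# hypothetical stand-ins.
if __name__ == "__main__":
    from torch.utils.data import DataLoader, Dataset

    class _ToyLatentDataset(Dataset):
        # Hypothetical dataset that alternates between two frame counts.
        def __len__(self):
            return 64

        def __getitem__(self, idx):
            t = 8 if idx % 2 == 0 else 16
            return {'video': torch.randn(1, 4, t, 32, 32)}

    loader = DataLoader(_ToyLatentDataset(), batch_size=4,
                        collate_fn=lambda batch: batch)
    bucketed = TemporalLengthBucketeer(loader, max_frames=16)
    batch = next(bucketed)
    print(batch['video'].shape)  # (4, 4, 8, 32, 32) or (4, 4, 16, 32, 32)

    _demo_bucketeer()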