from diffusers import DDPMScheduler
import torch


class HookedNoiseScheduler:
    """Wraps a DDPMScheduler so callables can run before and after every denoising step."""

    scheduler: DDPMScheduler
    pre_hooks: list
    post_hooks: list

    def __init__(self, scheduler):
        # Bypass the custom __setattr__ below so these attributes live on the wrapper
        # itself rather than being forwarded to the wrapped scheduler.
        object.__setattr__(self, 'scheduler', scheduler)
        object.__setattr__(self, 'pre_hooks', [])
        object.__setattr__(self, 'post_hooks', [])

    def step(
        self,
        model_output, timestep, sample, generator, return_dict
    ):
        assert return_dict == False, "return_dict == True is not implemented"

        # Pre-hooks may inspect or replace the step inputs; returning None keeps them unchanged.
        for hook in self.pre_hooks:
            hook_output = hook(model_output, timestep, sample, generator)
            if hook_output is not None:
                model_output, timestep, sample, generator = hook_output

        # With return_dict=False the underlying scheduler returns a one-element tuple.
        (pred_prev_sample,) = self.scheduler.step(model_output, timestep, sample, generator, return_dict)

        # Post-hooks may inspect or replace the predicted previous sample.
        for hook in self.post_hooks:
            hook_output = hook(pred_prev_sample)
            if hook_output is not None:
                pred_prev_sample = hook_output

        return (pred_prev_sample,)

    def __getattr__(self, name):
        # Delegate every other attribute lookup (config, timesteps, ...) to the wrapped scheduler.
        return getattr(self.scheduler, name)

    def __setattr__(self, name, value):
        # Keep the wrapper's own attributes local; forward all other assignments to the scheduler.
        if name in {'scheduler', 'pre_hooks', 'post_hooks'}:
            object.__setattr__(self, name, value)
        else:
            setattr(self.scheduler, name, value)
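
Registered hooks are plain callables: a pre-hook receives (model_output, timestep, sample, generator) and may return replacements for them, while a post-hook receives the predicted previous sample and may return a replacement. The snippet below is a minimal usage sketch; the hook names, tensor shapes, and scheduler configuration are illustrative assumptions, not part of the original code.

    # Minimal usage sketch (hypothetical hook names and shapes).
    scheduler = DDPMScheduler(num_train_timesteps=1000)
    hooked = HookedNoiseScheduler(scheduler)

    def log_timestep(model_output, timestep, sample, generator):
        # Pre-hook: observe the inputs; returning None leaves them unchanged.
        print(f"stepping at t={int(timestep)}")

    def clamp_sample(pred_prev_sample):
        # Post-hook: return a value to replace the predicted previous sample.
        return pred_prev_sample.clamp(-1.0, 1.0)

    hooked.pre_hooks.append(log_timestep)
    hooked.post_hooks.append(clamp_sample)

    # One denoising step with dummy tensors; return_dict must be False.
    sample = torch.randn(1, 3, 64, 64)
    model_output = torch.randn(1, 3, 64, 64)
    (sample,) = hooked.step(model_output, scheduler.timesteps[0], sample, None, False)

Because __getattr__ and __setattr__ forward to the wrapped scheduler, the hooked object can generally be dropped in wherever the original scheduler is used, provided the caller invokes step with return_dict=False.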