# Unboxing_SDXL_with_SAEs / SDLens / hooked_scheduler.py
# Uploaded by surokpro2 using huggingface_hub (commit 8cd00a9, 1.37 kB).
from diffusers import DDPMScheduler
import torch
class HookedNoiseScheduler:
    """Wrap a noise scheduler so that its ``step`` call can be intercepted.

    ``pre_hooks`` run before the wrapped scheduler's ``step`` and may replace
    its inputs; ``post_hooks`` run afterwards and may replace the predicted
    previous sample. Attribute reads the wrapper cannot resolve itself, and
    writes to any attribute other than ``scheduler`` / ``pre_hooks`` /
    ``post_hooks``, are forwarded to the wrapped scheduler, so the wrapper can
    be used in place of it.
    """

    # String annotations: documentation only, never evaluated at class creation.
    scheduler: "DDPMScheduler"  # the wrapped scheduler; must provide ``step``
    pre_hooks: list  # callables hook(model_output, timestep, sample, generator)
    post_hooks: list  # callables hook(pred_prev_sample)

    # Attributes stored on the wrapper itself rather than on the scheduler.
    _OWN_ATTRS = frozenset({"scheduler", "pre_hooks", "post_hooks"})

    def __init__(self, scheduler):
        """Wrap ``scheduler`` with empty pre- and post-hook lists."""
        # object.__setattr__ bypasses the forwarding __setattr__ below.
        object.__setattr__(self, "scheduler", scheduler)
        object.__setattr__(self, "pre_hooks", [])
        object.__setattr__(self, "post_hooks", [])

    def step(self, model_output, timestep, sample, generator=None, return_dict=False):
        """Run the pre-hooks, the wrapped scheduler's ``step``, then the post-hooks.

        Each pre-hook is called as ``hook(model_output, timestep, sample,
        generator)``. A non-``None`` return must be a 4-tuple that replaces
        those four values for later hooks and for the scheduler call.
        Each post-hook is called as ``hook(pred_prev_sample)``. A non-``None``
        return replaces the sample seen by later hooks and returned to the
        caller.

        Args:
            model_output: Model prediction forwarded to the scheduler.
            timestep: Timestep forwarded to the scheduler.
            sample: Current sample forwarded to the scheduler.
            generator: Optional RNG forwarded to the scheduler by keyword.
            return_dict: Must be falsy; dict-style outputs are not supported.

        Returns:
            A tuple whose first element is the (possibly hooked) previous
            sample, followed by any further elements the wrapped scheduler
            returned, unchanged.

        Raises:
            NotImplementedError: If ``return_dict`` is truthy.
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if return_dict:
            raise NotImplementedError("return_dict == True is not implemented")

        for hook in self.pre_hooks:
            hook_output = hook(model_output, timestep, sample, generator)
            if hook_output is not None:
                model_output, timestep, sample, generator = hook_output

        # Pass generator/return_dict by keyword: scheduler ``step`` signatures
        # differ in what follows ``sample`` positionally, so positional
        # passing can bind these values to unrelated parameters.
        outputs = self.scheduler.step(
            model_output,
            timestep,
            sample,
            generator=generator,
            return_dict=False,
        )
        # The wrapped scheduler may return more than one element; only the
        # first goes through the post-hooks, the rest are passed through.
        pred_prev_sample, *extras = outputs

        for hook in self.post_hooks:
            hook_output = hook(pred_prev_sample)
            if hook_output is not None:
                pred_prev_sample = hook_output

        return (pred_prev_sample, *extras)

    def __getattr__(self, name):
        """Forward attribute reads the wrapper cannot resolve to the scheduler."""
        # __getattr__ only runs after normal lookup fails. If one of the
        # wrapper's own attributes is missing (e.g. on an instance created via
        # __new__, as copy/pickle do), reading self.scheduler would re-enter
        # this method forever; raise AttributeError instead.
        if name in HookedNoiseScheduler._OWN_ATTRS:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(self.scheduler, name)

    def __setattr__(self, name, value):
        """Store own attributes locally; forward all other writes to the scheduler."""
        if name in HookedNoiseScheduler._OWN_ATTRS:
            object.__setattr__(self, name, value)
        else:
            setattr(self.scheduler, name, value)