File size: 1,025 Bytes
eb339cb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
from typing import Optional, Union
import torch
import diffusers
class LCMScheduler(diffusers.schedulers.LCMScheduler):
    """LCM scheduler that can use a fixed, pre-computed timestep schedule.

    When ``timesteps_step_map`` is given, :meth:`set_timesteps` takes the
    schedule for the requested number of inference steps from that mapping
    instead of delegating to the base diffusers scheduler. Without a map,
    construction and ``set_timesteps`` are forwarded to
    ``diffusers.schedulers.LCMScheduler`` unchanged.

    Args:
        timesteps_step_map: Optional mapping from a number of inference steps
            to the sequence of timesteps to use for that step count.
        **kwargs: Forwarded unchanged to the base ``LCMScheduler`` constructor.
    """

    def __init__(self, timesteps_step_map: Optional[dict] = None, **kwargs) -> None:
        # Assigned before the base initialiser runs so the overridden
        # set_timesteps() below can never observe a missing attribute, even if
        # the base class were to call it during construction.
        self.timesteps_step_map = timesteps_step_map
        super().__init__(**kwargs)

    def set_timesteps(
        self,
        num_inference_steps: Optional[int] = None,
        device: Union[str, torch.device] = None,
        **kwargs,
    ) -> None:
        """Set the discrete timesteps used for the denoising loop.

        Args:
            num_inference_steps: Number of denoising steps. Required when a
                ``timesteps_step_map`` was supplied, where it selects the
                schedule to use.
            device: Device on which ``self.timesteps`` is created.
            **kwargs: Forwarded to the base implementation when no map is set;
                ignored when the schedule comes from ``timesteps_step_map``.

        Raises:
            ValueError: If a map is set and ``num_inference_steps`` is None, or
                the mapped schedule is empty or contains a timestep outside
                ``[0, config.num_train_timesteps)``.
            KeyError: If a map is set but has no entry for
                ``num_inference_steps``.
        """
        if self.timesteps_step_map is None:
            super().set_timesteps(num_inference_steps=num_inference_steps, device=device, **kwargs)
            return

        # Explicit checks instead of `assert`: asserts vanish under `python -O`.
        if num_inference_steps is None:
            raise ValueError("`num_inference_steps` must be provided when `timesteps_step_map` is set.")
        try:
            mapped = self.timesteps_step_map[num_inference_steps]
        except KeyError:
            raise KeyError(
                f"No timestep schedule for num_inference_steps={num_inference_steps}; "
                f"available step counts: {list(self.timesteps_step_map)}"
            ) from None

        # Materialise as plain ints so lists, numpy arrays and tensors are all
        # validated the same way; int() truncates like the long cast below.
        timesteps = [int(t) for t in mapped]
        if not timesteps:
            raise ValueError(f"Timestep schedule for num_inference_steps={num_inference_steps} is empty.")
        num_train_timesteps = self.config.num_train_timesteps
        # A negative timestep would otherwise be accepted silently, so both
        # ends of the valid range are checked.
        out_of_range = [t for t in timesteps if not 0 <= t < num_train_timesteps]
        if out_of_range:
            raise ValueError(
                f"Timesteps {out_of_range} are outside [0, {num_train_timesteps}) "
                f"for num_inference_steps={num_inference_steps}."
            )

        self.num_inference_steps = num_inference_steps
        self.timesteps = torch.tensor(timesteps, dtype=torch.long, device=device)
        # Clear the cached step position so it is recomputed for the new
        # schedule.
        self._step_index = None
        # NOTE(review): recent diffusers releases also keep `_begin_index` and
        # clear it in the base set_timesteps(); mirrored here so a stale value
        # is not reused (a harmless extra attribute on older versions) —
        # confirm against the pinned diffusers version.
        self._begin_index = None
|