|
"""k-diffusion transformer diffusion models, version 2. |
|
Code adopted from https://github.com/crowsonkb/k-diffusion
|
""" |
|
|
|
from contextlib import contextmanager |
|
from functools import update_wrapper |
|
import os |
|
import threading |
|
|
|
import torch |
|
|
|
|
|
def get_use_compile():
    """Return whether compilation is enabled through the environment.

    Reads the ``K_DIFFUSION_USE_COMPILE`` environment variable on every call.
    The result is ``True`` when the variable is unset or exactly ``"1"``; any
    other value (``"0"``, ``""``, ``"true"``, ...) yields ``False``.
    """
    flag = os.environ.get("K_DIFFUSION_USE_COMPILE", "1")
    return flag == "1"
|
|
|
|
|
def get_use_flash_attention_2():
    """Return whether the ``K_DIFFUSION_USE_FLASH_2`` switch is on.

    Reads the ``K_DIFFUSION_USE_FLASH_2`` environment variable on every call.
    The result is ``True`` when the variable is unset or exactly ``"1"``; any
    other value yields ``False``.
    """
    flag = os.environ.get("K_DIFFUSION_USE_FLASH_2", "1")
    return flag == "1"
|
|
|
|
|
# Per-thread flag storage. Attributes on a ``threading.local`` are only
# visible in the thread that assigned them, so the default below exists solely
# in the thread that imports this module. Code that reads the flag must
# tolerate the attribute being absent in other threads.
state = threading.local()
state.checkpointing = False


@contextmanager
def checkpointing(enable=True):
    """Temporarily set the calling thread's checkpointing flag.

    On entry, ``state.checkpointing`` is set to ``enable`` for the current
    thread. On exit, including when the body raises, the value that was in
    effect before entry is restored. A thread that has never set the flag is
    treated as having it ``False``.

    Args:
        enable: Value to store in ``state.checkpointing`` for the duration of
            the ``with`` block. Defaults to ``True``.

    Yields:
        None.
    """
    # Read with a default: in threads other than the importing one the
    # attribute does not exist yet. Doing this before the ``try`` also
    # guarantees ``previous`` is bound when the ``finally`` clause runs.
    previous = getattr(state, "checkpointing", False)
    state.checkpointing = enable
    try:
        yield
    finally:
        state.checkpointing = previous
|
|
|
|
|
def get_checkpointing():
    """Return the calling thread's checkpointing flag.

    Returns ``state.checkpointing`` for the current thread, or ``False`` when
    the attribute has not been set in this thread.
    """
    try:
        return state.checkpointing
    except AttributeError:
        # The thread-local attribute is absent in this thread.
        return False
|
|
|
|
|
class compile_wrap:
    """Callable wrapper that defers ``torch.compile`` until first use.

    The wrapped function and the extra arguments destined for
    ``torch.compile`` are stored at construction time. Nothing is compiled
    until ``compiled_function`` is first read, which happens on the first
    call. The resolved callable is cached on the instance afterwards.

    Attributes:
        function: The original, uncompiled callable.
        args: Extra positional arguments forwarded to ``torch.compile``.
        kwargs: Extra keyword arguments forwarded to ``torch.compile``.
    """

    def __init__(self, function, *args, **kwargs):
        """Store ``function`` and the arguments to pass to ``torch.compile``."""
        self.function = function
        self.args, self.kwargs = args, kwargs
        # ``None`` means "not resolved yet"; filled in by ``compiled_function``.
        self._compiled_function = None
        # Copy the wrapped function's metadata (``__name__``, ``__doc__``, ...)
        # onto this instance and set ``__wrapped__``.
        update_wrapper(self, function)

    @property
    def compiled_function(self):
        """Return the callable to invoke, resolving and caching it on first use.

        When ``get_use_compile()`` is true, ``torch.compile`` is applied to
        the wrapped function with the stored arguments. If that raises
        ``RuntimeError``, the original function is used instead. When
        compilation is disabled, the original function is used directly.
        """
        resolved = self._compiled_function
        if resolved is None:
            # Default to the plain function; replaced below only if
            # compilation is enabled and succeeds.
            resolved = self.function
            if get_use_compile():
                try:
                    resolved = torch.compile(self.function, *self.args, **self.kwargs)
                except RuntimeError:
                    # Best-effort: fall back to the uncompiled function.
                    resolved = self.function
            self._compiled_function = resolved
        return resolved

    def __call__(self, *args, **kwargs):
        """Invoke the resolved (compiled or original) function."""
        target = self.compiled_function
        return target(*args, **kwargs)