File size: 1,583 Bytes
b7f3942 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
"""k-diffusion transformer diffusion models, version 2.
Code adapted from https://github.com/crowsonkb/k-diffusion
"""
from contextlib import contextmanager
from functools import update_wrapper
import os
import threading
import torch
def get_use_compile():
    """Report whether compilation is enabled through the environment.

    Returns:
        True when ``K_DIFFUSION_USE_COMPILE`` is unset or is exactly ``"1"``;
        False for any other value.
    """
    setting = os.getenv("K_DIFFUSION_USE_COMPILE", "1")
    return setting == "1"
def get_use_flash_attention_2():
    """Report whether the ``K_DIFFUSION_USE_FLASH_2`` switch is on.

    Returns:
        True when ``K_DIFFUSION_USE_FLASH_2`` is unset or is exactly ``"1"``;
        False for any other value.
    """
    # NOTE(review): presumably gates FlashAttention-2 usage elsewhere; the
    # consumers of this flag are not visible in this module.
    setting = os.getenv("K_DIFFUSION_USE_FLASH_2", "1")
    return setting == "1"
# Per-thread flag storage. The assignment below runs only in the thread that
# imports this module; in every other thread the attribute starts out unset,
# so readers must fall back to False (e.g. via getattr with a default).
state = threading.local()
state.checkpointing = False


@contextmanager
def checkpointing(enable=True):
    """Temporarily set the calling thread's checkpointing flag.

    Args:
        enable: Value the flag holds while the ``with`` body runs.

    Yields:
        None. The previous flag value is restored on exit, including when
        the body raises.
    """
    # Read with a default: threading.local attributes are per-thread, so a
    # thread other than the importing one has no ``checkpointing`` attribute
    # yet. Doing this before the ``try`` also guarantees ``previous`` is bound
    # by the time the ``finally`` clause runs.
    previous = getattr(state, "checkpointing", False)
    state.checkpointing = enable
    try:
        yield
    finally:
        state.checkpointing = previous
def get_checkpointing():
    """Return the calling thread's checkpointing flag.

    Returns:
        The value of ``state.checkpointing`` for the current thread, or
        False when the attribute has not been set in this thread.
    """
    try:
        return state.checkpointing
    except AttributeError:
        # The attribute is absent in this thread's view of ``state``.
        return False
class compile_wrap:
    """Callable wrapper that defers ``torch.compile`` until first use.

    The wrapped function is compiled the first time ``compiled_function`` is
    accessed (which happens on the first call). The result is cached on the
    instance, so later changes to the environment switch read by
    ``get_use_compile()`` do not affect an already-resolved wrapper.
    """

    def __init__(self, function, *args, **kwargs):
        """Store the target function and the arguments for ``torch.compile``.

        Args:
            function: The callable to wrap.
            *args: Extra positional arguments forwarded to ``torch.compile``.
            **kwargs: Extra keyword arguments forwarded to ``torch.compile``.
        """
        self._compiled_function = None
        self.function = function
        self.args = args
        self.kwargs = kwargs
        # Copy the wrapped function's metadata onto this instance. Kept last,
        # as in the original ordering, because update_wrapper also merges the
        # function's __dict__ into self.__dict__.
        update_wrapper(self, function)

    @property
    def compiled_function(self):
        """Return the callable to dispatch to, resolving it on first access.

        Returns:
            The result of ``torch.compile(function, *args, **kwargs)`` when
            ``get_use_compile()`` is true and compilation setup does not raise
            ``RuntimeError``; otherwise the original, uncompiled function.
        """
        if self._compiled_function is None:
            target = self.function
            if get_use_compile():
                try:
                    target = torch.compile(self.function, *self.args, **self.kwargs)
                except RuntimeError:
                    # Best-effort: fall back to the plain function.
                    target = self.function
            self._compiled_function = target
        return self._compiled_function

    def __call__(self, *args, **kwargs):
        """Invoke the resolved (compiled or plain) function."""
        target = self.compiled_function
        return target(*args, **kwargs)