Spaces:
Runtime error
Runtime error
import importlib.util | |
import logging | |
import warnings | |
import importlib_metadata | |
from packaging import version | |
# Module-level logger, used below for the version debug messages.
logger = logging.getLogger(__name__)

# xformers probe: the flag stays True only when the `xformers` module is
# importable and installed-distribution metadata exists for both `xformers`
# and `torch`. A torch older than 1.12 alongside xformers is a hard error
# raised while this module is being imported.
_xformers_available = importlib.util.find_spec("xformers") is not None
if _xformers_available:
    try:
        _xformers_version = importlib_metadata.version("xformers")
        _torch_version = importlib_metadata.version("torch")
    except importlib_metadata.PackageNotFoundError:
        # One of the two distributions has no installed metadata.
        _xformers_available = False
    else:
        if version.Version(_torch_version) < version.Version("1.12"):
            raise ValueError("xformers is installed but requires PyTorch >= 1.12")
        logger.debug(f"Successfully imported xformers version {_xformers_version}")
# Triton probe. The Triton-backed kernels (TritonLiteMLA, TritonMBConvPreGLU)
# need triton >= 3.0.0. Triton is optional, so an unusable installation
# (no package metadata, or too old) clears the flag and emits a warning.
# `is_triton_module_available()` then reports False; the module import
# does not fail.
_triton_modules_available = importlib.util.find_spec("triton") is not None
if _triton_modules_available:
    try:
        _triton_version = importlib_metadata.version("triton")
    except importlib_metadata.PackageNotFoundError:
        # An importable `triton` module exists, but no distribution named
        # "triton" is installed to report its version.
        _triton_modules_available = False
        warnings.warn("TritonLiteMLA and TritonMBConvPreGLU with `triton` is not available on your platform.")
    else:
        if version.Version(_triton_version) < version.Version("3.0.0"):
            # Disable the Triton kernels instead of raising. An uncaught
            # error here would make the whole module unimportable.
            _triton_modules_available = False
            warnings.warn(
                f"triton {_triton_version} is installed but TritonLiteMLA and "
                "TritonMBConvPreGLU require Triton >= 3.0.0; they are disabled."
            )
        else:
            logger.debug("Successfully imported triton version %s", _triton_version)
def is_xformers_available() -> bool:
    """Return whether xformers was detected when this module was imported.

    The flag is computed once at import time. It is True only if the
    `xformers` module is importable, package metadata exists for both
    `xformers` and `torch`, and torch is at least 1.12.
    """
    return _xformers_available
def is_triton_module_available() -> bool:
    """Return whether the optional `triton` dependency was detected at import time.

    True only if the `triton` module is importable, its package metadata
    is present, and its version is at least 3.0.0.
    """
    return _triton_modules_available