Spaces:
Paused
Paused
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Import utilities: helpers for detecting optional dependencies and for lazily importing submodules.
"""
import importlib.metadata as importlib_metadata | |
import importlib.util | |
import os | |
from itertools import chain | |
from types import ModuleType | |
from typing import Any | |
from . import logging | |
# Module-level logger, obtained from the package's own `logging` helper module.
logger = logging.get_logger(__name__)

# Spellings (compared after upper-casing) that count as "enabled" for an environment variable.
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
# The same spellings plus "AUTO", for switches where automatic detection is also acceptable.
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES | {"AUTO"}

# Upper-cased framework selectors; both default to "AUTO" when the variable is unset.
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()

# Boolean: True only when IMAGE_AUX_SLOW_IMPORT holds one of the "enabled" spellings
# (unset defaults to "FALSE", i.e. disabled).
IMAGE_AUX_SLOW_IMPORT = (
    os.environ.get("IMAGE_AUX_SLOW_IMPORT", "FALSE").upper() in ENV_VARS_TRUE_VALUES
)
# Decide whether PyTorch may be used. It is only considered when USE_TORCH is truthy or
# "AUTO" and USE_TF is not truthy; it must then be findable on the import path AND have
# installed distribution metadata. `_torch_version` stays "N/A" unless the version is read.
_torch_version = "N/A"
_torch_available = False
if USE_TORCH not in ENV_VARS_TRUE_AND_AUTO_VALUES or USE_TF in ENV_VARS_TRUE_VALUES:
    logger.info("Disabling PyTorch because USE_TORCH is set")
elif importlib.util.find_spec("torch") is not None:
    try:
        _torch_version = importlib_metadata.version("torch")
    except importlib_metadata.PackageNotFoundError:
        # Importable but without distribution metadata: treat as unavailable.
        pass
    else:
        _torch_available = True
        logger.info(f"PyTorch version {_torch_version} available.")
# Detect the optional `transformers` dependency. It counts as available only when the
# module is findable on the import path AND its distribution metadata is installed.
# `_transformers_version` always exists (like `_torch_version`), defaulting to "N/A",
# so referencing it never raises NameError when transformers is absent.
_transformers_version = "N/A"
_transformers_available = importlib.util.find_spec("transformers") is not None
if _transformers_available:
    # Only consult metadata (and log success) when the module is actually importable.
    try:
        _transformers_version = importlib_metadata.version("transformers")
        logger.debug(f"Successfully imported transformers version {_transformers_version}")
    except importlib_metadata.PackageNotFoundError:
        _transformers_available = False
def is_torch_available():
    """Report whether PyTorch was detected and enabled when this module was imported.

    The flag is computed once at import time from the ``USE_TORCH`` / ``USE_TF``
    environment variables and the installed ``torch`` package; this function only
    returns that stored result and does not re-check the environment.
    """
    return _torch_available
def is_transformers_available():
    """Report whether ``transformers`` was detected when this module was imported.

    The flag is computed once at import time (the module must be findable and have
    installed distribution metadata); this function only returns that stored result.
    """
    return _transformers_available
class OptionalDependencyNotAvailable(BaseException):
    """Signals that an optional dependency of Diffusers is missing from the environment.

    It derives from ``BaseException`` rather than ``Exception``, so a generic
    ``except Exception`` handler will not intercept it; catch it by name instead.
    """
class _LazyModule(ModuleType):
    """
    Module class that surfaces all objects but only performs associated imports when the objects are requested.

    Accessing a submodule name imports ``<name>.<submodule>``; accessing an exported object
    name imports the submodule that defines it and returns the attribute. Either result is
    then stored on the instance with ``setattr`` so later lookups bypass ``__getattr__``.

    Args:
        name: Module name; also the anchor package for relative submodule imports.
        module_file: File path of the module (presumably the package ``__init__.py`` —
            confirm with callers); its directory becomes ``__path__``.
        import_structure: Mapping of submodule name -> iterable of names that submodule exports.
        module_spec: Optional value stored as ``__spec__``.
        extra_objects: Optional mapping of names to objects returned directly, without importing.
    """

    # Very heavily inspired by optuna.integration._IntegrationModule
    # https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
    def __init__(self, name, module_file, import_structure, module_spec=None, extra_objects=None):
        super().__init__(name)
        self._modules = set(import_structure.keys())
        # Reverse index: exported object name -> submodule defining it. If a name appears
        # under several submodules, the last one in iteration order wins.
        self._class_to_module = {
            value: key for key, values in import_structure.items() for value in values
        }
        # Needed for autocompletion in an IDE
        self.__all__ = list(import_structure.keys()) + list(chain(*import_structure.values()))
        self.__file__ = module_file
        self.__spec__ = module_spec
        self.__path__ = [os.path.dirname(module_file)]
        self._objects = {} if extra_objects is None else extra_objects
        self._name = name
        self._import_structure = import_structure

    # Needed for autocompletion in an IDE
    def __dir__(self):
        result = super().__dir__()
        # The elements of self.__all__ that are submodules may or may not be in the dir already,
        # depending on whether they have been accessed (and cached via setattr) or not. Track what
        # is already listed in a set so each membership test is O(1) rather than a linear scan of
        # the growing list; order and de-duplication match appending only unseen names.
        seen = set(result)
        for attr in self.__all__:
            if attr not in seen:
                seen.add(attr)
                result.append(attr)
        return result

    def __getattr__(self, name: str) -> Any:
        # Invoked only when normal attribute lookup fails, i.e. for names not cached yet.
        if name in self._objects:
            return self._objects[name]
        if name in self._modules:
            value = self._get_module(name)
        elif name in self._class_to_module:
            module = self._get_module(self._class_to_module[name])
            value = getattr(module, name)
        else:
            raise AttributeError(f"module {self.__name__} has no attribute {name}")
        # Cache on the instance so subsequent accesses skip this method entirely.
        setattr(self, name, value)
        return value

    def _get_module(self, module_name: str):
        """Import and return ``<self.__name__>.<module_name>``.

        Raises:
            RuntimeError: wrapping (and chained from) any exception raised by the import.
        """
        try:
            return importlib.import_module("." + module_name, self.__name__)
        except Exception as e:
            raise RuntimeError(
                f"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its"
                f" traceback):\n{e}"
            ) from e

    def __reduce__(self):
        # Pickle support: rebuild through __init__. Note that module_spec and extra_objects
        # are not part of the reconstruction arguments and are therefore not carried over.
        return (self.__class__, (self._name, self.__file__, self._import_structure))