Spaces:
Runtime error
Runtime error
# Copyright (C) 2021-2024, Mindee. | |
# This program is licensed under the Apache License 2.0. | |
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details. | |
# Adapted from https://github.com/huggingface/transformers/blob/master/src/transformers/file_utils.py | |
import importlib.metadata | |
import importlib.util | |
import logging | |
import os | |
from typing import Optional | |
CLASS_NAME: str = "words" | |
__all__ = ["is_tf_available", "is_torch_available", "requires_package", "CLASS_NAME"] | |
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} | |
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"}) | |
USE_TF = os.environ.get("USE_TF", "AUTO").upper() | |
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() | |
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: | |
_torch_available = importlib.util.find_spec("torch") is not None | |
if _torch_available: | |
try: | |
_torch_version = importlib.metadata.version("torch") | |
logging.info(f"PyTorch version {_torch_version} available.") | |
except importlib.metadata.PackageNotFoundError: # pragma: no cover | |
_torch_available = False | |
else: # pragma: no cover | |
logging.info("Disabling PyTorch because USE_TF is set") | |
_torch_available = False | |
if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES: | |
_tf_available = importlib.util.find_spec("tensorflow") is not None | |
if _tf_available: | |
candidates = ( | |
"tensorflow", | |
"tensorflow-cpu", | |
"tensorflow-gpu", | |
"tf-nightly", | |
"tf-nightly-cpu", | |
"tf-nightly-gpu", | |
"intel-tensorflow", | |
"tensorflow-rocm", | |
"tensorflow-macos", | |
) | |
_tf_version = None | |
# For the metadata, we have to look for both tensorflow and tensorflow-cpu | |
for pkg in candidates: | |
try: | |
_tf_version = importlib.metadata.version(pkg) | |
break | |
except importlib.metadata.PackageNotFoundError: | |
pass | |
_tf_available = _tf_version is not None | |
if _tf_available: | |
if int(_tf_version.split(".")[0]) < 2: # type: ignore[union-attr] # pragma: no cover | |
logging.info(f"TensorFlow found but with version {_tf_version}. DocTR requires version 2 minimum.") | |
_tf_available = False | |
else: | |
logging.info(f"TensorFlow version {_tf_version} available.") | |
else: # pragma: no cover | |
logging.info("Disabling Tensorflow because USE_TORCH is set") | |
_tf_available = False | |
if not _torch_available and not _tf_available: # pragma: no cover | |
raise ModuleNotFoundError( | |
"DocTR requires either TensorFlow or PyTorch to be installed. Please ensure one of them" | |
" is installed and that either USE_TF or USE_TORCH is enabled." | |
) | |
def requires_package(name: str, extra_message: Optional[str] = None) -> None: # pragma: no cover | |
""" | |
package requirement helper | |
Args: | |
---- | |
name: name of the package | |
extra_message: additional message to display if the package is not found | |
""" | |
try: | |
_pkg_version = importlib.metadata.version(name) | |
logging.info(f"{name} version {_pkg_version} available.") | |
except importlib.metadata.PackageNotFoundError: | |
raise ImportError( | |
f"\n\n{extra_message if extra_message is not None else ''} " | |
f"\nPlease install it with the following command: pip install {name}\n" | |
) | |
def is_torch_available(): | |
"""Whether PyTorch is installed.""" | |
return _torch_available | |
def is_tf_available(): | |
"""Whether TensorFlow is installed.""" | |
return _tf_available | |