File size: 3,738 Bytes
153628e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# Copyright (C) 2021-2024, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

# Adapted from https://github.com/huggingface/transformers/blob/master/src/transformers/file_utils.py

import importlib.metadata
import importlib.util
import logging
import os
from typing import Optional

# String constant re-exported through __all__ below; it is not referenced by the rest of the
# code visible in this module.
CLASS_NAME: str = "words"


# Public names exposed by a star-import of this module.
__all__ = ["is_tf_available", "is_torch_available", "requires_package", "CLASS_NAME"]

# Upper-cased environment values that explicitly switch a backend on.
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
# The same values plus "AUTO", i.e. every value that does not rule a backend out.
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES | {"AUTO"}

# Backend preferences read from the environment (upper-cased); an unset variable means "AUTO".
USE_TF = os.getenv("USE_TF", "AUTO").upper()
USE_TORCH = os.getenv("USE_TORCH", "AUTO").upper()


# PyTorch is probed only when USE_TORCH allows it (true-ish or AUTO) and USE_TF is not explicitly on.
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
    # Start pessimistic; the flag flips only once the distribution metadata lookup succeeds.
    _torch_available = False
    if importlib.util.find_spec("torch") is not None:
        try:
            _torch_version = importlib.metadata.version("torch")
        except importlib.metadata.PackageNotFoundError:  # pragma: no cover
            # An importable module without distribution metadata is treated as unavailable.
            pass
        else:
            _torch_available = True
            logging.info(f"PyTorch version {_torch_version} available.")
else:  # pragma: no cover
    logging.info("Disabling PyTorch because USE_TF is set")
    _torch_available = False


# TensorFlow is probed only when USE_TF allows it (true-ish or AUTO) and USE_TORCH is not explicitly on.
if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
    _tf_available = importlib.util.find_spec("tensorflow") is not None
    if _tf_available:
        # The version has to be read from distribution metadata, and the distribution can be
        # installed under any of these names: keep the version of the first one that is found.
        candidates = (
            "tensorflow",
            "tensorflow-cpu",
            "tensorflow-gpu",
            "tf-nightly",
            "tf-nightly-cpu",
            "tf-nightly-gpu",
            "intel-tensorflow",
            "tensorflow-rocm",
            "tensorflow-macos",
        )
        _tf_version = None
        for pkg in candidates:
            try:
                _tf_version = importlib.metadata.version(pkg)
            except importlib.metadata.PackageNotFoundError:
                continue
            break

        if _tf_version is None:
            # Module spec found, but none of the known distributions is installed.
            _tf_available = False
        elif int(_tf_version.partition(".")[0]) < 2:  # pragma: no cover
            logging.info(f"TensorFlow found but with version {_tf_version}. DocTR requires version 2 minimum.")
            _tf_available = False
        else:
            logging.info(f"TensorFlow version {_tf_version} available.")
else:  # pragma: no cover
    logging.info("Disabling Tensorflow because USE_TORCH is set")
    _tf_available = False


# Importing this module fails outright when neither backend ended up enabled.
if not (_torch_available or _tf_available):  # pragma: no cover
    raise ModuleNotFoundError(
        "DocTR requires either TensorFlow or PyTorch to be installed. "
        "Please ensure one of them is installed and that either USE_TF or USE_TORCH is enabled."
    )


def requires_package(name: str, extra_message: Optional[str] = None) -> None:  # pragma: no cover
    """
    package requirement helper

    Looks up the installed distribution metadata for `name` and logs the version that was found.

    Args:
    ----
        name: name of the package
        extra_message: additional message to display if the package is not found

    Raises:
    ------
        ImportError: if no installed distribution named `name` is found; the underlying
            `PackageNotFoundError` is attached as the explicit cause
    """
    # Keep the try body down to the single call that can raise PackageNotFoundError.
    try:
        _pkg_version = importlib.metadata.version(name)
    except importlib.metadata.PackageNotFoundError as exc:
        # Chain explicitly so the traceback presents the metadata lookup failure as the direct
        # cause instead of "another exception occurred during handling".
        raise ImportError(
            f"\n\n{extra_message if extra_message is not None else ''} "
            f"\nPlease install it with the following command: pip install {name}\n"
        ) from exc
    logging.info(f"{name} version {_pkg_version} available.")


def is_torch_available() -> bool:
    """Whether PyTorch is usable as a backend.

    Returns:
    -------
        the module-level `_torch_available` flag computed once at import time: True only when the
        `USE_TORCH` / `USE_TF` environment settings allow PyTorch and an installed `torch`
        distribution was found
    """
    return _torch_available


def is_tf_available() -> bool:
    """Whether TensorFlow is usable as a backend.

    Returns:
    -------
        the module-level `_tf_available` flag computed once at import time: True only when the
        `USE_TF` / `USE_TORCH` environment settings allow TensorFlow, an installed TensorFlow
        distribution was found, and its major version is at least 2
    """
    return _tf_available