File size: 7,772 Bytes
72268ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
from __future__ import annotations

import functools as _functools
import os as _os
import warnings as _warnings
from glob import glob as _glob

# expose these two exceptions as part of the API.  Everything else should feed into these.
from .exceptions import ConversionError, InvalidArchiveError  # NOQA
from .interface import AbstractBaseFormat
from .tarball import CondaTarBZ2 as _CondaTarBZ2
from .utils import filter_info_files
from .utils import get_executor as _get_executor
from .utils import rm_rf as _rm_rf

SUPPORTED_EXTENSIONS: dict[str, type[AbstractBaseFormat]] = {".tar.bz2": _CondaTarBZ2}

libarchive_enabled = False  #: Old API meaning "can extract .conda" (now without libarchive)

try:
    from .conda_fmt import ZSTD_COMPRESS_LEVEL, ZSTD_COMPRESS_THREADS
    from .conda_fmt import CondaFormat_v2 as _CondaFormat_v2

    SUPPORTED_EXTENSIONS[".conda"] = _CondaFormat_v2

    libarchive_enabled = True

except ImportError:
    _warnings.warn("Install zstandard Python bindings for .conda support")

THREADSAFE_EXTRACT = True  #: Not present in conda-package-handling<2.0.


def _collect_paths(prefix):
    dir_paths, file_paths = [], []
    for dp, dn, filenames in _os.walk(prefix):
        for f in filenames:
            file_paths.append(_os.path.relpath(_os.path.join(dp, f), prefix))
        dir_paths.extend(_os.path.relpath(_os.path.join(dp, _), prefix) for _ in dn)
    file_list = file_paths + [
        dp for dp in dir_paths if not any(f.startswith(dp + _os.sep) for f in file_paths)
    ]
    return file_list


def get_default_extracted_folder(in_file, abspath=True):
    dirname = None
    for ext in SUPPORTED_EXTENSIONS:
        if in_file.endswith(ext):
            dirname = in_file[: -len(ext)]
    if dirname and not abspath:
        dirname = _os.path.basename(dirname)
    return dirname


def extract(fn, dest_dir=None, components=None, prefix=None):
    if dest_dir:
        if _os.path.isabs(dest_dir) and prefix:
            raise ValueError(
                "dest_dir and prefix both provided as abs paths.  If providing both, "
                "prefix can be abspath, but dest dir must be relative (relative to "
                "prefix)"
            )
        if not _os.path.isabs(dest_dir):
            dest_dir = _os.path.normpath(_os.path.join(prefix or _os.getcwd(), dest_dir))
    else:
        dest_dir = _os.path.join(
            prefix or _os.path.dirname(fn),
            get_default_extracted_folder(fn, abspath=False),
        )

    if not _os.path.isdir(dest_dir):
        _os.makedirs(dest_dir)

    for format in SUPPORTED_EXTENSIONS.values():
        if format.supported(fn):
            format.extract(fn, dest_dir, components=components)
            break
    else:
        raise ValueError(
            "Didn't recognize extension for file '{}'.  Supported extensions are: {}".format(
                fn, list(SUPPORTED_EXTENSIONS.keys())
            )
        )


def create(prefix, file_list, out_fn, out_folder=None, **kw):
    if not out_folder:
        out_folder = _os.getcwd()

    # simplify arguments to format.create()
    if _os.path.isabs(out_fn):
        out_folder = _os.path.dirname(out_fn)
        out_fn = _os.path.basename(out_fn)

    if file_list is None:
        file_list = _collect_paths(prefix)
    elif isinstance(file_list, str):
        try:
            with open(file_list) as f:
                data = f.readlines()
            file_list = [_.strip() for _ in data]
        except:
            raise

    out = None
    for format in SUPPORTED_EXTENSIONS.values():
        if format.supported(out_fn):
            try:
                out = format.create(prefix, file_list, out_fn, out_folder, **kw)
                break
            except BaseException as err:
                # don't leave broken files around
                abs_out_fn = _os.path.join(out_folder, out_fn)
                if _os.path.isfile(abs_out_fn):
                    _rm_rf(abs_out_fn)
                raise err
    else:
        raise ValueError(
            "Didn't recognize extension for file '{}'.  Supported extensions are: {}".format(
                out_fn, list(SUPPORTED_EXTENSIONS.keys())
            )
        )

    return out


def _convert(
    fn,
    out_ext,
    out_folder,
    force=False,
    zstd_compress_level=None,
    zstd_compress_threads=None,
    **kw,
):
    # allow package to work in degraded mode when zstandard is not available
    import conda_package_streaming.transmute
    import zstandard

    basename = get_default_extracted_folder(fn, abspath=False)
    from .validate import validate_converted_files_match_streaming

    if not basename:
        return (
            fn,
            "",
            "Input file %s doesn't have a supported extension (%s), skipping it"
            % (fn, SUPPORTED_EXTENSIONS),
        )
    out_fn = str(_os.path.join(out_folder, basename + out_ext))
    errors = ""
    if not _os.path.lexists(out_fn) or force:
        if force and _os.path.lexists(out_fn):
            _os.unlink(out_fn)

        if out_ext == ".conda":
            # ZSTD_COMPRESS_* constants are only defined if we have zstandard
            if zstd_compress_level is None:
                zstd_compress_level = ZSTD_COMPRESS_LEVEL
            if zstd_compress_threads is None:
                zstd_compress_threads = ZSTD_COMPRESS_THREADS

            def compressor():
                return zstandard.ZstdCompressor(
                    level=zstd_compress_level, threads=zstd_compress_threads
                )

            def is_info(filename):
                return filter_info_files([filename], prefix=".") == []

            transmute = _functools.partial(
                conda_package_streaming.transmute.transmute,
                fn,
                out_folder,
                compressor=compressor,
                is_info=is_info,
            )

        else:
            transmute = _functools.partial(
                conda_package_streaming.transmute.transmute_tar_bz2, fn, out_folder
            )

        try:
            transmute()
            result = validate_converted_files_match_streaming(out_fn, fn)
            _, missing_files, mismatching_sizes = result
            if missing_files or mismatching_sizes:
                errors = str(ConversionError(missing_files, mismatching_sizes))
        except BaseException as e:
            # don't leave partial package around
            if _os.path.isfile(out_fn):
                _rm_rf(out_fn)
            if not isinstance(e, Exception):
                raise
            errors = str(e)

    return fn, out_fn, errors


def transmute(in_file, out_ext, out_folder=None, processes=1, **kw):
    if not out_folder:
        out_folder = _os.path.dirname(in_file) or _os.getcwd()

    flist = set(_glob(in_file))
    if in_file.endswith(".tar.bz2"):
        flist = flist - set(_glob(in_file.replace(".tar.bz2", out_ext)))
    elif in_file.endswith(".conda"):
        flist = flist - set(_glob(in_file.replace(".conda", out_ext)))

    failed_files = {}
    with _get_executor(processes) as executor:
        convert_f = _functools.partial(_convert, out_ext=out_ext, out_folder=out_folder, **kw)
        for fn, out_fn, errors in executor.map(convert_f, flist):
            if errors:
                failed_files[fn] = errors
                _rm_rf(out_fn)
    return failed_files


def get_pkg_details(in_file):
    """For the new pkg format, we return the size and hashes of the inner pkg part of the file"""
    for format in SUPPORTED_EXTENSIONS.values():
        if format.supported(in_file):
            details = format.get_pkg_details(in_file)
            break
    else:
        raise ValueError(f"Don't know what to do with file {in_file}")
    return details