Spaces:
Running
Running
"""Helper functions for a standard streaming compression API""" | |
from zipfile import ZipFile | |
import fsspec.utils | |
from fsspec.spec import AbstractBufferedFile | |
def noop_file(file, mode, **kwargs):
    """Return ``file`` unchanged; the handler used when no compression applies.

    ``mode`` and any extra keyword arguments are accepted only so that this
    function has the same call form as the other compression callbacks; they
    are ignored.
    """
    return file
# TODO: files should also be available as contexts

# Registry mapping a compression name to the callable that wraps a file.
# Every value should be a function of the form
# func(infile, mode=, **kwargs) -> file-like.
# The ``None`` key is the "no compression" entry and hands the file back as-is.
compr = {None: noop_file}
def register_compression(name, callback, extensions, force=False):
    """Register an "inferable" file compression type.

    Registers transparent file compression type for use with fsspec.open.
    Compression can be specified by name in open, or "infer"-ed for any files
    ending with the given extensions.

    Args:
        name: (str) The compression type name. Eg. "gzip".
        callback: A callable of form (infile, mode, **kwargs) -> file-like.
            Accepts an input file-like object, the target mode and kwargs.
            Returns a wrapped file-like object.
        extensions: (str, Iterable[str]) A file extension, or list of file
            extensions for which to infer this compression scheme. Eg. "gz".
            Any iterable is accepted, including one-shot iterators.
        force: (bool) Force re-registration of compression type or extensions.

    Raises:
        ValueError: If name or extensions already registered, and not force.
    """
    # ``extensions`` is walked twice below (validate, then register).  Copy it
    # into a list first so that a generator or other one-shot iterator is not
    # exhausted by the validation pass, which would silently register nothing.
    if isinstance(extensions, str):
        extensions = [extensions]
    else:
        extensions = list(extensions)

    # Validate everything before touching either registry, so a rejected
    # registration leaves no partial state behind.
    if not force:
        if name in compr:
            raise ValueError(f"Duplicate compression registration: {name}")
        for ext in extensions:
            if ext in fsspec.utils.compressions:
                raise ValueError(
                    f"Duplicate compression file extension: {ext} ({name})"
                )

    compr[name] = callback

    # Extension -> compression-name table used for "infer"-ed compression.
    for ext in extensions:
        fsspec.utils.compressions[ext] = name
def unzip(infile, mode="rb", filename=None, **kwargs):
    """Open a single member of a zip archive held in ``infile``.

    Args:
        infile: File-like object containing (or receiving) the zip archive.
        mode: If it contains "r" the archive is read; otherwise a new
            archive is written.
        filename: Name of the archive member. When reading, defaults to the
            first name listed in the archive; when writing, defaults to
            "file".
        **kwargs: Passed to ``ZipFile.open`` when reading and to the
            ``ZipFile`` constructor when writing.

    Returns:
        A file-like object for the member. In write mode, closing it also
        closes the enclosing archive.
    """
    if "r" in mode:
        archive = ZipFile(infile)
        member = archive.namelist()[0] if filename is None else filename
        return archive.open(member, mode="r", **kwargs)

    archive = ZipFile(infile, mode="w", **kwargs)
    member_file = archive.open(filename or "file", mode="w")
    close_member = member_file.close

    def close_member_then_archive():
        # Finish the member first, then close the archive itself.
        return close_member() or archive.close()

    member_file.close = close_member_then_archive
    return member_file
# Zip archives: inferred from the "zip" extension and handled by ``unzip``.
register_compression(name="zip", callback=unzip, extensions="zip")
# bz2 support is registered only when the bz2 module can be imported.
try:
    from bz2 import BZ2File
except ImportError:
    pass
else:
    register_compression(name="bz2", callback=BZ2File, extensions="bz2")
# gzip: use isal's IGzipFile when the isal package is importable, otherwise
# fall back to the standard library's GzipFile.
try:  # pragma: no cover
    from isal import igzip

    def isal(infile, mode="rb", **kwargs):
        """Wrap ``infile`` in an ``igzip.IGzipFile`` opened with ``mode``."""
        return igzip.IGzipFile(fileobj=infile, mode=mode, **kwargs)

    register_compression("gzip", isal, "gz")
except ImportError:
    from gzip import GzipFile

    # ``mode`` is accepted as an explicit (optionally positional) parameter
    # so this fallback honours the same (infile, mode, **kwargs) call form
    # as the isal wrapper above.  Its default of None is GzipFile's own
    # default, so callers that pass mode by keyword, or not at all, are
    # unaffected.
    register_compression(
        "gzip",
        lambda f, mode=None, **kwargs: GzipFile(fileobj=f, mode=mode, **kwargs),
        "gz",
    )
# lzma / xz: both names are served by the same LZMAFile class, registered
# only when the lzma module can be imported.
try:
    from lzma import LZMAFile
except ImportError:
    pass
else:
    register_compression("lzma", LZMAFile, "lzma")
    register_compression("xz", LZMAFile, "xz")
# When lzmaffi is importable its LZMAFile replaces any earlier "lzma"/"xz"
# registration; force=True permits overwriting names and extensions that
# are already registered.
try:
    import lzmaffi
except ImportError:
    pass
else:
    register_compression("lzma", lzmaffi.LZMAFile, "lzma", force=True)
    register_compression("xz", lzmaffi.LZMAFile, "xz", force=True)
class SnappyFile(AbstractBufferedFile):
    """Non-seekable buffered file that runs data through a snappy stream codec.

    In a read mode the bytes read from the wrapped file are passed through a
    ``snappy.StreamDecompressor``; otherwise buffered bytes are passed through
    a ``snappy.StreamCompressor`` before being written to the wrapped file.
    """

    def __init__(self, infile, mode, **kwargs):
        """Wrap ``infile`` and pick the codec matching ``mode``.

        Args:
            infile: Underlying file-like object that is read from or written
                to.
            mode: Open mode; a "b" is always ensured on the mode handed to
                the base class. A mode containing "r" selects decompression.
            **kwargs: Forwarded to the base-class constructor.
        """
        import snappy

        binary_mode = mode.strip("b") + "b"
        # NOTE(review): size=999999999 looks like a placeholder, presumably
        # because the real decompressed size is not known up front - confirm.
        super().__init__(
            fs=None, path="snappy", mode=binary_mode, size=999999999, **kwargs
        )
        self.infile = infile
        codec_factory = (
            snappy.StreamDecompressor if "r" in mode else snappy.StreamCompressor
        )
        self.codec = codec_factory()

    def _upload_chunk(self, final=False):
        """Compress the whole write buffer and write it to the wrapped file.

        ``final`` is not consulted. Always returns True.
        """
        self.buffer.seek(0)
        pending = self.buffer.read()
        self.infile.write(self.codec.add_chunk(pending))
        return True

    def seek(self, loc, whence=0):
        """Always raise: seeking is not supported."""
        raise NotImplementedError("SnappyFile is not seekable")

    def seekable(self):
        """Report that this file cannot seek."""
        return False

    def _fetch_range(self, start, end):
        """Read ``end - start`` bytes from the wrapped file and decompress them.

        The byte count applies to the read on the wrapped file, not to the
        length of the decompressed result.
        """
        compressed = self.infile.read(end - start)
        return self.codec.decompress(compressed)
try:
    import snappy

    # Probe the module with a trivial call; failures of the kinds listed in
    # the ``except`` clause below skip the registration.
    snappy.compress(b"")

    # Snappy may use the .sz file extension, but this is not part of the
    # standard implementation, so no extension is registered for inference.
    register_compression(name="snappy", callback=SnappyFile, extensions=[])
except (AttributeError, NameError, ImportError):
    pass
# lz4 frame format: registered only when the lz4 package is importable.
try:
    import lz4.frame
except ImportError:
    pass
else:
    register_compression("lz4", lz4.frame.open, "lz4")
# Zstandard: registered only when the zstandard package is importable.
try:
    import zstandard as zstd
except ImportError:
    pass
else:

    def zstandard_file(infile, mode="rb"):
        """Wrap ``infile`` in a zstandard stream reader or writer.

        A mode containing "r" returns ``ZstdDecompressor().stream_reader``;
        any other mode returns a ``stream_writer`` from a compressor created
        with ``level=10``.
        """
        if "r" not in mode:
            compressor = zstd.ZstdCompressor(level=10)
            return compressor.stream_writer(infile)
        decompressor = zstd.ZstdDecompressor()
        return decompressor.stream_reader(infile)

    register_compression("zstd", zstandard_file, "zst")
def available_compressions():
    """Return the names of all registered compressions as a new list.

    The list holds the keys of the ``compr`` registry in its iteration order.
    """
    return [*compr]