"""Helper functions for a standard streaming compression API"""
from zipfile import ZipFile

import fsspec.utils
from fsspec.spec import AbstractBufferedFile


def noop_file(file, mode, **kwargs):
    """Return ``file`` unchanged; the callback registered for "no compression"."""
    return file


# TODO: files should also be available as contexts
# should be functions of the form func(infile, mode=, **kwargs) -> file-like
compr = {None: noop_file}


def register_compression(name, callback, extensions, force=False):
    """Register an "inferable" file compression type.

    Registers transparent file compression type for use with fsspec.open.
    Compression can be specified by name in open, or "infer"-ed for any files
    ending with the given extensions.

    Args:
        name: (str) The compression type name. Eg. "gzip".
        callback: A callable of form (infile, mode, **kwargs) -> file-like.
            Accepts an input file-like object, the target mode and kwargs.
            Returns a wrapped file-like object.
        extensions: (str, Iterable[str]) A file extension, or list of file
            extensions for which to infer this compression scheme. Eg. "gz".
            Any iterable is accepted, including one-shot iterators.
        force: (bool) Force re-registration of compression type or extensions.

    Raises:
        ValueError: If name or extensions already registered, and not force.
    """
    if isinstance(extensions, str):
        extensions = [extensions]
    else:
        # The extensions are walked twice below (validate, then register).
        # Materialize them so a one-shot iterable (generator, map, ...) is not
        # exhausted by validation, which would silently register nothing.
        extensions = list(extensions)

    # Validate everything first so a failure leaves both registries untouched.
    if name in compr and not force:
        raise ValueError(f"Duplicate compression registration: {name}")

    for ext in extensions:
        if ext in fsspec.utils.compressions and not force:
            raise ValueError(f"Duplicate compression file extension: {ext} ({name})")

    compr[name] = callback

    for ext in extensions:
        fsspec.utils.compressions[ext] = name


def unzip(infile, mode="rb", filename=None, **kwargs):
    """Open a single member of a zip archive held in ``infile``.

    Args:
        infile: file-like object containing (or receiving) the archive.
        mode: if it does not contain "r", a new archive is written.
        filename: archive member to open. Defaults to "file" when writing and
            to the first name in the archive listing when reading.
        **kwargs: passed to the ``ZipFile`` constructor when writing, and to
            ``ZipFile.open`` when reading.

    Returns:
        A file-like object for the selected archive member.
    """
    if "r" not in mode:
        filename = filename or "file"
        z = ZipFile(infile, mode="w", **kwargs)
        fo = z.open(filename, mode="w")
        # Closing the member must also close the archive; the original close
        # is bound as a default so the replacement does not call itself.
        fo.close = lambda closer=fo.close: closer() or z.close()
        return fo
    z = ZipFile(infile)
    if filename is None:
        filename = z.namelist()[0]
    return z.open(filename, mode="r", **kwargs)


register_compression("zip", unzip, "zip")

try:
    from bz2 import BZ2File
except ImportError:
    pass
else:
    register_compression("bz2", BZ2File, "bz2")

try:  # pragma: no cover
    from isal import igzip

    def isal(infile, mode="rb", **kwargs):
        """Wrap ``infile`` in an ``igzip.IGzipFile`` opened with ``mode``."""
        return igzip.IGzipFile(fileobj=infile, mode=mode, **kwargs)

    register_compression("gzip", isal, "gz")
except ImportError:
    # isal is unavailable: fall back to the standard-library gzip module.
    from gzip import GzipFile

    register_compression(
        "gzip", lambda f, **kwargs: GzipFile(fileobj=f, **kwargs), "gz"
    )

try:
    from lzma import LZMAFile

    register_compression("lzma", LZMAFile, "lzma")
    register_compression("xz", LZMAFile, "xz")
except ImportError:
    pass

try:
    import lzmaffi

    # When lzmaffi is importable it replaces any lzma/xz registration made
    # above, hence force=True.
    register_compression("lzma", lzmaffi.LZMAFile, "lzma", force=True)
    register_compression("xz", lzmaffi.LZMAFile, "xz", force=True)
except ImportError:
    pass


class SnappyFile(AbstractBufferedFile):
    """Non-seekable buffered file applying snappy stream (de)compression.

    Wraps ``infile``: a stream decompressor is used when "r" is in ``mode``,
    a stream compressor otherwise.
    """

    def __init__(self, infile, mode, **kwargs):
        import snappy

        super().__init__(
            fs=None,
            path="snappy",
            # Normalize to a binary mode regardless of how the caller spelt it.
            mode=mode.strip("b") + "b",
            # Arbitrary large placeholder handed to the base class; presumably
            # used because the decompressed length is not known up front.
            size=999999999,
            **kwargs,
        )
        self.infile = infile
        if "r" in mode:
            self.codec = snappy.StreamDecompressor()
        else:
            self.codec = snappy.StreamCompressor()

    def _upload_chunk(self, final=False):
        """Compress the current buffer contents and write them to ``infile``."""
        self.buffer.seek(0)
        out = self.codec.add_chunk(self.buffer.read())
        self.infile.write(out)
        return True

    def seek(self, loc, whence=0):
        """Always raises: seeking is not supported on this stream."""
        raise NotImplementedError("SnappyFile is not seekable")

    def seekable(self):
        """Return False; this file cannot be seeked."""
        return False

    def _fetch_range(self, start, end):
        """Get the specified set of bytes from remote"""
        # Reads ``end - start`` bytes from the wrapped file at its current
        # position (``start`` is not used as an offset) and decompresses them.
        data = self.infile.read(end - start)
        return self.codec.decompress(data)


try:
    import snappy

    # Probe that the imported module actually provides ``compress``.
    snappy.compress(b"")
    # Snappy may use the .sz file extension, but this is not part of the
    # standard implementation.
    register_compression("snappy", SnappyFile, [])

except (ImportError, NameError, AttributeError):
    pass

try:
    import lz4.frame

    register_compression("lz4", lz4.frame.open, "lz4")
except ImportError:
    pass

try:
    import zstandard as zstd

    def zstandard_file(infile, mode="rb"):
        """Wrap ``infile`` in a zstandard stream reader or writer.

        Returns a decompressing stream reader when "r" is in ``mode``,
        otherwise a compressing stream writer using compression level 10.
        """
        if "r" in mode:
            dctx = zstd.ZstdDecompressor()
            return dctx.stream_reader(infile)
        else:
            cctx = zstd.ZstdCompressor(level=10)
            return cctx.stream_writer(infile)

    register_compression("zstd", zstandard_file, "zst")
except ImportError:
    pass


def available_compressions():
    """Return a list of the implemented compressions."""
    return list(compr)