# coding=utf-8 # Copyright 2025 HuggingFace Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import io import json from typing import List, Literal, Optional, Union, cast import requests from .deprecation_utils import deprecate from .import_utils import is_safetensors_available, is_torch_available if is_torch_available(): import torch from ..image_processor import VaeImageProcessor from ..video_processor import VideoProcessor if is_safetensors_available(): import safetensors.torch DTYPE_MAP = { "float16": torch.float16, "float32": torch.float32, "bfloat16": torch.bfloat16, "uint8": torch.uint8, } from PIL import Image def detect_image_type(data: bytes) -> str: if data.startswith(b"\xff\xd8"): return "jpeg" elif data.startswith(b"\x89PNG\r\n\x1a\n"): return "png" elif data.startswith(b"GIF87a") or data.startswith(b"GIF89a"): return "gif" elif data.startswith(b"BM"): return "bmp" return "unknown" def check_inputs_decode( endpoint: str, tensor: "torch.Tensor", processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None, do_scaling: bool = True, scaling_factor: Optional[float] = None, shift_factor: Optional[float] = None, output_type: Literal["mp4", "pil", "pt"] = "pil", return_type: Literal["mp4", "pil", "pt"] = "pil", image_format: Literal["png", "jpg"] = "jpg", partial_postprocess: bool = False, input_tensor_type: Literal["binary"] = "binary", output_tensor_type: Literal["binary"] = "binary", height: Optional[int] = None, width: Optional[int] = None, ): if tensor.ndim == 3 and height is None and width is None: raise ValueError("`height` and `width` required for packed latents.") if ( output_type == "pt" and return_type == "pil" and not partial_postprocess and not isinstance(processor, (VaeImageProcessor, VideoProcessor)) ): raise ValueError("`processor` is required.") if do_scaling and scaling_factor is None: deprecate( "do_scaling", "1.0.0", "`do_scaling` is deprecated, pass `scaling_factor` and `shift_factor` if required.", standard_warn=False, ) def postprocess_decode( response: requests.Response, processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None, output_type: Literal["mp4", "pil", "pt"] = "pil", return_type: Literal["mp4", "pil", "pt"] = "pil", partial_postprocess: bool = False, ): if output_type == "pt" or (output_type == "pil" and processor is not None): output_tensor = response.content parameters = response.headers shape = json.loads(parameters["shape"]) dtype = parameters["dtype"] torch_dtype = DTYPE_MAP[dtype] output_tensor = torch.frombuffer(bytearray(output_tensor), dtype=torch_dtype).reshape(shape) if output_type == "pt": if partial_postprocess: if return_type == "pil": output = [Image.fromarray(image.numpy()) for image in output_tensor] if len(output) == 1: output = output[0] elif return_type == "pt": output = output_tensor else: if processor is None or return_type == "pt": output = output_tensor else: if isinstance(processor, VideoProcessor): output = cast( List[Image.Image], processor.postprocess_video(output_tensor, output_type="pil")[0], ) else: output = cast( Image.Image, processor.postprocess(output_tensor, output_type="pil")[0], ) elif output_type == "pil" and return_type == "pil" and processor is None: output = Image.open(io.BytesIO(response.content)).convert("RGB") detected_format = detect_image_type(response.content) output.format = detected_format elif output_type == "pil" and processor is not None: if return_type == "pil": output = [ Image.fromarray(image) for image in (output_tensor.permute(0, 2, 3, 1).float().numpy() * 255).round().astype("uint8") ] elif return_type == "pt": output = output_tensor elif output_type == "mp4" and return_type == "mp4": output = response.content return output def prepare_decode( tensor: "torch.Tensor", processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None, do_scaling: bool = True, scaling_factor: Optional[float] = None, shift_factor: Optional[float] = None, output_type: Literal["mp4", "pil", "pt"] = "pil", image_format: Literal["png", "jpg"] = "jpg", partial_postprocess: bool = False, height: Optional[int] = None, width: Optional[int] = None, ): headers = {} parameters = { "image_format": image_format, "output_type": output_type, "partial_postprocess": partial_postprocess, "shape": list(tensor.shape), "dtype": str(tensor.dtype).split(".")[-1], } if do_scaling and scaling_factor is not None: parameters["scaling_factor"] = scaling_factor if do_scaling and shift_factor is not None: parameters["shift_factor"] = shift_factor if do_scaling and scaling_factor is None: parameters["do_scaling"] = do_scaling elif do_scaling and scaling_factor is None and shift_factor is None: parameters["do_scaling"] = do_scaling if height is not None and width is not None: parameters["height"] = height parameters["width"] = width headers["Content-Type"] = "tensor/binary" headers["Accept"] = "tensor/binary" if output_type == "pil" and image_format == "jpg" and processor is None: headers["Accept"] = "image/jpeg" elif output_type == "pil" and image_format == "png" and processor is None: headers["Accept"] = "image/png" elif output_type == "mp4": headers["Accept"] = "text/plain" tensor_data = safetensors.torch._tobytes(tensor, "tensor") return {"data": tensor_data, "params": parameters, "headers": headers} def remote_decode( endpoint: str, tensor: "torch.Tensor", processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None, do_scaling: bool = True, scaling_factor: Optional[float] = None, shift_factor: Optional[float] = None, output_type: Literal["mp4", "pil", "pt"] = "pil", return_type: Literal["mp4", "pil", "pt"] = "pil", image_format: Literal["png", "jpg"] = "jpg", partial_postprocess: bool = False, input_tensor_type: Literal["binary"] = "binary", output_tensor_type: Literal["binary"] = "binary", height: Optional[int] = None, width: Optional[int] = None, ) -> Union[Image.Image, List[Image.Image], bytes, "torch.Tensor"]: """ Hugging Face Hybrid Inference that allow running VAE decode remotely. Args: endpoint (`str`): Endpoint for Remote Decode. tensor (`torch.Tensor`): Tensor to be decoded. processor (`VaeImageProcessor` or `VideoProcessor`, *optional*): Used with `return_type="pt"`, and `return_type="pil"` for Video models. do_scaling (`bool`, default `True`, *optional*): **DEPRECATED**. **pass `scaling_factor`/`shift_factor` instead.** **still set do_scaling=None/do_scaling=False for no scaling until option is removed** When `True` scaling e.g. `latents / self.vae.config.scaling_factor` is applied remotely. If `False`, input must be passed with scaling applied. scaling_factor (`float`, *optional*): Scaling is applied when passed e.g. [`latents / self.vae.config.scaling_factor`](https://github.com/huggingface/diffusers/blob/7007febae5cff000d4df9059d9cf35133e8b2ca9/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L1083C37-L1083C77). - SD v1: 0.18215 - SD XL: 0.13025 - Flux: 0.3611 If `None`, input must be passed with scaling applied. shift_factor (`float`, *optional*): Shift is applied when passed e.g. `latents + self.vae.config.shift_factor`. - Flux: 0.1159 If `None`, input must be passed with scaling applied. output_type (`"mp4"` or `"pil"` or `"pt", default `"pil"): **Endpoint** output type. Subject to change. Report feedback on preferred type. `"mp4": Supported by video models. Endpoint returns `bytes` of video. `"pil"`: Supported by image and video models. Image models: Endpoint returns `bytes` of an image in `image_format`. Video models: Endpoint returns `torch.Tensor` with partial `postprocessing` applied. Requires `processor` as a flag (any `None` value will work). `"pt"`: Support by image and video models. Endpoint returns `torch.Tensor`. With `partial_postprocess=True` the tensor is postprocessed `uint8` image tensor. Recommendations: `"pt"` with `partial_postprocess=True` is the smallest transfer for full quality. `"pt"` with `partial_postprocess=False` is the most compatible with third party code. `"pil"` with `image_format="jpg"` is the smallest transfer overall. return_type (`"mp4"` or `"pil"` or `"pt", default `"pil"): **Function** return type. `"mp4": Function returns `bytes` of video. `"pil"`: Function returns `PIL.Image.Image`. With `output_type="pil" no further processing is applied. With `output_type="pt" a `PIL.Image.Image` is created. `partial_postprocess=False` `processor` is required. `partial_postprocess=True` `processor` is **not** required. `"pt"`: Function returns `torch.Tensor`. `processor` is **not** required. `partial_postprocess=False` tensor is `float16` or `bfloat16`, without denormalization. `partial_postprocess=True` tensor is `uint8`, denormalized. image_format (`"png"` or `"jpg"`, default `jpg`): Used with `output_type="pil"`. Endpoint returns `jpg` or `png`. partial_postprocess (`bool`, default `False`): Used with `output_type="pt"`. `partial_postprocess=False` tensor is `float16` or `bfloat16`, without denormalization. `partial_postprocess=True` tensor is `uint8`, denormalized. input_tensor_type (`"binary"`, default `"binary"`): Tensor transfer type. output_tensor_type (`"binary"`, default `"binary"`): Tensor transfer type. height (`int`, **optional**): Required for `"packed"` latents. width (`int`, **optional**): Required for `"packed"` latents. Returns: output (`Image.Image` or `List[Image.Image]` or `bytes` or `torch.Tensor`). """ if input_tensor_type == "base64": deprecate( "input_tensor_type='base64'", "1.0.0", "input_tensor_type='base64' is deprecated. Using `binary`.", standard_warn=False, ) input_tensor_type = "binary" if output_tensor_type == "base64": deprecate( "output_tensor_type='base64'", "1.0.0", "output_tensor_type='base64' is deprecated. Using `binary`.", standard_warn=False, ) output_tensor_type = "binary" check_inputs_decode( endpoint, tensor, processor, do_scaling, scaling_factor, shift_factor, output_type, return_type, image_format, partial_postprocess, input_tensor_type, output_tensor_type, height, width, ) kwargs = prepare_decode( tensor=tensor, processor=processor, do_scaling=do_scaling, scaling_factor=scaling_factor, shift_factor=shift_factor, output_type=output_type, image_format=image_format, partial_postprocess=partial_postprocess, height=height, width=width, ) response = requests.post(endpoint, **kwargs) if not response.ok: raise RuntimeError(response.json()) output = postprocess_decode( response=response, processor=processor, output_type=output_type, return_type=return_type, partial_postprocess=partial_postprocess, ) return output def check_inputs_encode( endpoint: str, image: Union["torch.Tensor", Image.Image], scaling_factor: Optional[float] = None, shift_factor: Optional[float] = None, ): pass def postprocess_encode( response: requests.Response, ): output_tensor = response.content parameters = response.headers shape = json.loads(parameters["shape"]) dtype = parameters["dtype"] torch_dtype = DTYPE_MAP[dtype] output_tensor = torch.frombuffer(bytearray(output_tensor), dtype=torch_dtype).reshape(shape) return output_tensor def prepare_encode( image: Union["torch.Tensor", Image.Image], scaling_factor: Optional[float] = None, shift_factor: Optional[float] = None, ): headers = {} parameters = {} if scaling_factor is not None: parameters["scaling_factor"] = scaling_factor if shift_factor is not None: parameters["shift_factor"] = shift_factor if isinstance(image, torch.Tensor): data = safetensors.torch._tobytes(image.contiguous(), "tensor") parameters["shape"] = list(image.shape) parameters["dtype"] = str(image.dtype).split(".")[-1] else: buffer = io.BytesIO() image.save(buffer, format="PNG") data = buffer.getvalue() return {"data": data, "params": parameters, "headers": headers} def remote_encode( endpoint: str, image: Union["torch.Tensor", Image.Image], scaling_factor: Optional[float] = None, shift_factor: Optional[float] = None, ) -> "torch.Tensor": """ Hugging Face Hybrid Inference that allow running VAE encode remotely. Args: endpoint (`str`): Endpoint for Remote Decode. image (`torch.Tensor` or `PIL.Image.Image`): Image to be encoded. scaling_factor (`float`, *optional*): Scaling is applied when passed e.g. [`latents * self.vae.config.scaling_factor`]. - SD v1: 0.18215 - SD XL: 0.13025 - Flux: 0.3611 If `None`, input must be passed with scaling applied. shift_factor (`float`, *optional*): Shift is applied when passed e.g. `latents - self.vae.config.shift_factor`. - Flux: 0.1159 If `None`, input must be passed with scaling applied. Returns: output (`torch.Tensor`). """ check_inputs_encode( endpoint, image, scaling_factor, shift_factor, ) kwargs = prepare_encode( image=image, scaling_factor=scaling_factor, shift_factor=shift_factor, ) response = requests.post(endpoint, **kwargs) if not response.ok: raise RuntimeError(response.json()) output = postprocess_encode( response=response, ) return output