AutoEval / doctr /io /image /pytorch.py
adirathor07's picture
added doctr folder
153628e
# Copyright (C) 2021-2024, Mindee.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
from io import BytesIO
from typing import Tuple
import numpy as np
import torch
from PIL import Image
from torchvision.transforms.functional import to_tensor
from doctr.utils.common_types import AbstractPath
__all__ = ["tensor_from_pil", "read_img_as_tensor", "decode_img_as_tensor", "tensor_from_numpy", "get_img_shape"]
def tensor_from_pil(pil_img: Image.Image, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Turn a PIL Image into a PyTorch tensor

    Args:
    ----
        pil_img: the PIL image to convert
        dtype: data type requested for the resulting tensor

    Returns:
    -------
        the image as a tensor
    """
    if dtype != torch.float32:
        # Every non-float32 request goes through the numpy route, fed with a uint8 copy of the pixels
        pixels = np.array(pil_img, np.uint8, copy=True)
        return tensor_from_numpy(pixels, dtype)

    # float32 is delegated to torchvision's own converter
    return to_tensor(pil_img)
def read_img_as_tensor(img_path: AbstractPath, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Load an image from disk and return it as a PyTorch tensor

    Args:
    ----
        img_path: path of the image file to open
        dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255.

    Returns:
    -------
        the image, converted to RGB, as a tensor

    Raises:
    ------
        ValueError: if `dtype` is not torch.uint8, torch.float16 or torch.float32
    """
    supported_dtypes = (torch.uint8, torch.float16, torch.float32)
    if dtype not in supported_dtypes:
        raise ValueError("insupported value for dtype")

    with Image.open(img_path, mode="r") as opened_img:
        # Force 3 channels before handing over to the tensor conversion
        rgb_img = opened_img.convert("RGB")
        return tensor_from_pil(rgb_img, dtype)
def decode_img_as_tensor(img_content: bytes, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Decode an in-memory image byte stream into a PyTorch tensor

    Args:
    ----
        img_content: raw bytes of an encoded image, opened through an in-memory buffer
        dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255.

    Returns:
    -------
        the image, converted to RGB, as a tensor

    Raises:
    ------
        ValueError: if `dtype` is not torch.uint8, torch.float16 or torch.float32
    """
    supported_dtypes = (torch.uint8, torch.float16, torch.float32)
    if dtype not in supported_dtypes:
        raise ValueError("insupported value for dtype")

    # Wrap the bytes in a file-like object so PIL can read them like a file
    buffer = BytesIO(img_content)
    with Image.open(buffer, mode="r") as opened_img:
        rgb_img = opened_img.convert("RGB")
        return tensor_from_pil(rgb_img, dtype)
def tensor_from_numpy(npy_img: np.ndarray, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Convert an image stored as a numpy array into a PyTorch tensor

    Args:
    ----
        npy_img: image encoded as a numpy array of shape (H, W, C) in np.uint8. A single-channel array of
            shape (H, W) is also accepted and handled as (H, W, 1).
        dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255.

    Returns:
    -------
        same image as a tensor of shape (C, H, W)

    Raises:
    ------
        ValueError: if `dtype` is not torch.uint8, torch.float16 or torch.float32
    """
    if dtype not in (torch.uint8, torch.float16, torch.float32):
        raise ValueError("insupported value for dtype")

    # Give 2D (grayscale) inputs an explicit channel axis so every dtype behaves the same:
    # the permute below needs three axes, and would otherwise fail for uint8 / float16.
    if npy_img.ndim == 2:
        npy_img = npy_img[:, :, None]

    if dtype == torch.float32:
        img = to_tensor(npy_img)
    else:
        img = torch.from_numpy(npy_img)
        # put it from HWC to CHW format
        img = img.permute((2, 0, 1)).contiguous()
        if dtype == torch.float16:
            # Switch to FP16
            img = img.to(dtype=torch.float16).div(255)

    return img
def get_img_shape(img: torch.Tensor) -> Tuple[int, int]:
    """Get the shape of an image

    Args:
    ----
        img: image tensor; its trailing two dimensions are read

    Returns:
    -------
        the last two entries of `img.shape`
    """
    trailing_dims = img.shape[-2:]
    return trailing_dims  # type: ignore[return-value]