import torch
from typing import Tuple


class TensorImgUtils:
    @staticmethod
    def from_to(from_type: list[str], to_type: list[str]):
        """Return a function that converts a tensor from one type (dimension
        order) to another. Args can be lists of strings or plain strings
        (e.g., ["C", "H", "W"] or just "CHW").
        """
        if isinstance(from_type, list):
            from_type = "".join(from_type)
        if isinstance(to_type, list):
            to_type = "".join(to_type)
        permute_arg = [from_type.index(c) for c in to_type]

        def convert(tensor: torch.Tensor) -> torch.Tensor:
            return tensor.permute(permute_arg)

        return convert

    @staticmethod
    def convert_to_type(tensor: torch.Tensor, to_type: str) -> torch.Tensor:
        """Convert an image tensor to the given type, e.g. "CHW" or "BHWC"."""
        from_type = TensorImgUtils.identify_type(tensor)[0]
        if from_type == list(to_type):
            return tensor
        if len(from_type) == 4 and len(to_type) == 3:
            # Converting from a batched tensor to a non-batched one: squeeze the batch dimension.
            tensor = tensor.squeeze(0)
            from_type = from_type[1:]
        if len(from_type) == 3 and len(to_type) == 4:
            # Converting from a non-batched tensor to a batched one: unsqueeze a batch dimension.
            tensor = tensor.unsqueeze(0)
            from_type = ["B"] + from_type
        return TensorImgUtils.from_to(from_type, list(to_type))(tensor)

    @staticmethod
    def identify_type(tensor: torch.Tensor) -> Tuple[list[str], str]:
        """Identify the type of an image tensor. Doesn't currently check for BHW.

        Returns one of the following (as a (dims, label) tuple): HW, HWRGB,
        HWRGBA, RGBHW, RGBAHW, HWA, AHW, BHWRGB, BHWRGBA, BRGBHW, or BRGBAHW.
        """
        dim_n = tensor.dim()
        if dim_n == 2:
            return (["H", "W"], "HW")
        elif dim_n == 3:  # HWA, AHW, HWC, or CHW
            if tensor.size(2) == 3:
                return (["H", "W", "C"], "HWRGB")
            elif tensor.size(2) == 4:
                return (["H", "W", "C"], "HWRGBA")
            elif tensor.size(0) == 3:
                return (["C", "H", "W"], "RGBHW")
            elif tensor.size(0) == 4:
                return (["C", "H", "W"], "RGBAHW")
            elif tensor.size(2) == 1:
                return (["H", "W", "C"], "HWA")
            elif tensor.size(0) == 1:
                return (["C", "H", "W"], "AHW")
        elif dim_n == 4:  # BHWC or BCHW
            if tensor.size(3) >= 3:  # BHWRGB or BHWRGBA
                if tensor.size(3) == 3:
                    return (["B", "H", "W", "C"], "BHWRGB")
                elif tensor.size(3) == 4:
                    return (["B", "H", "W", "C"], "BHWRGBA")
            elif tensor.size(1) >= 3:
                if tensor.size(1) == 3:
                    return (["B", "C", "H", "W"], "BRGBHW")
                elif tensor.size(1) == 4:
                    return (["B", "C", "H", "W"], "BRGBAHW")
        else:
            raise ValueError(
                f"{dim_n} dimensions is not a valid number of dimensions for an image tensor."
            )
        raise ValueError(
            f"Could not determine shape of Tensor with {dim_n} dimensions and {tensor.shape} shape."
        )

    @staticmethod
    def test_squeeze_batch(tensor: torch.Tensor, strict=False) -> torch.Tensor:
        # Check whether the tensor has a batch dimension (4 dimensions).
        if tensor.dim() == 4:
            if tensor.size(0) == 1 or not strict:
                # A batch dimension of size 1 represents a single image, so remove it.
                return tensor.squeeze(0)
            else:
                raise ValueError(
                    f"This is not a single image. It's a batch of {tensor.size(0)} images."
                )
        else:
            # No batch dimension, so return the tensor as is.
            return tensor

    @staticmethod
    def test_unsqueeze_batch(tensor: torch.Tensor) -> torch.Tensor:
        # Check whether the tensor lacks a batch dimension (3 dimensions).
        if tensor.dim() == 3:
            # It represents a single image, so add a batch dimension.
            return tensor.unsqueeze(0)
        else:
            # It already has a batch dimension, so return the tensor as is.
            return tensor

    @staticmethod
    def most_pixels(img_tensors: list[torch.Tensor]) -> torch.Tensor:
        """Return the tensor with the most pixels (height * width)."""
        sizes = [
            TensorImgUtils.height_width(img)[0] * TensorImgUtils.height_width(img)[1]
            for img in img_tensors
        ]
        return img_tensors[sizes.index(max(sizes))]

    @staticmethod
    def height_width(image: torch.Tensor) -> Tuple[int, int]:
        """Like torchvision.transforms methods, this method assumes the Tensor to
        have [..., H, W] shape, where ... means an arbitrary number of leading
        dimensions.
        """
        return image.shape[-2:]

    @staticmethod
    def smaller_axis(image: torch.Tensor) -> int:
        # Index of the smaller spatial axis, assuming a batched BCHW layout.
        h, w = TensorImgUtils.height_width(image)
        return 2 if h < w else 3

    @staticmethod
    def larger_axis(image: torch.Tensor) -> int:
        # Index of the larger spatial axis, assuming a batched BCHW layout.
        h, w = TensorImgUtils.height_width(image)
        return 2 if h > w else 3
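

if __name__ == "__main__":
    # Illustrative usage sketch (not part of the original module); assumes a
    # synthetic CHW RGB tensor just to show how the helpers compose.
    chw = torch.rand(3, 64, 48)  # C, H, W
    print(TensorImgUtils.identify_type(chw))  # (['C', 'H', 'W'], 'RGBHW')

    # Convert to a batched BHWC layout and back to CHW.
    bhwc = TensorImgUtils.convert_to_type(chw, "BHWC")
    print(bhwc.shape)  # torch.Size([1, 64, 48, 3])
    restored = TensorImgUtils.convert_to_type(bhwc, "CHW")
    print(restored.shape)  # torch.Size([3, 64, 48])

    # Pick the image with the most pixels and read its spatial size.
    small = torch.rand(3, 32, 32)
    largest = TensorImgUtils.most_pixels([chw, small])
    print(TensorImgUtils.height_width(largest))  # torch.Size([64, 48])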