import torch
from typing import Tuple, Union


class TensorImgUtils:
    @staticmethod
    def from_to(from_type: Union[str, list[str]], to_type: Union[str, list[str]]):
        """Return a function that permutes a tensor from one axis order to
        another. Arguments may be lists of axis labels or plain strings
        (e.g., ["C", "H", "W"] or just "CHW")."""
        if isinstance(from_type, list):
            from_type = "".join(from_type)
        if isinstance(to_type, list):
            to_type = "".join(to_type)
        # Map each target axis label to its position in the source layout.
        permute_arg = [from_type.index(c) for c in to_type]

        def convert(tensor: torch.Tensor) -> torch.Tensor:
            return tensor.permute(permute_arg)

        return convert
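    # Usage sketch (illustrative; not from the original file):
    #   convert = TensorImgUtils.from_to("CHW", "HWC")
    #   hwc = convert(torch.rand(3, 32, 32))  # hwc.shape -> torch.Size([32, 32, 3])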
    @staticmethod
    def convert_to_type(tensor: torch.Tensor, to_type: str) -> torch.Tensor:
        """Convert an image tensor to the given axis order, adding or removing
        the batch dimension if needed."""
        from_type = TensorImgUtils.identify_type(tensor)[0]
        if from_type == list(to_type):
            return tensor
        if len(from_type) == 4 and len(to_type) == 3:
            # Batched -> non-batched: squeeze the batch dimension (this is
            # only a true removal when the batch size is 1).
            tensor = tensor.squeeze(0)
            from_type = from_type[1:]
        if len(from_type) == 3 and len(to_type) == 4:
            # Non-batched -> batched: unsqueeze a batch dimension of size 1.
            tensor = tensor.unsqueeze(0)
            from_type = ["B"] + from_type
        return TensorImgUtils.from_to(from_type, list(to_type))(tensor)
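    # Usage sketch (illustrative; not from the original file):
    #   chw = TensorImgUtils.convert_to_type(torch.rand(32, 32, 3), "CHW")
    #   # chw.shape -> torch.Size([3, 32, 32])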
    @staticmethod
    def identify_type(tensor: torch.Tensor) -> Tuple[list[str], str]:
        """Identify the layout of an image tensor by its dimension count and
        channel sizes. Returns a tuple of (axis labels, type string), e.g.,
        (["H", "W", "C"], "HWRGB"). Doesn't currently check for BHW."""
        dim_n = tensor.dim()
        if dim_n == 2:
            return (["H", "W"], "HW")
        elif dim_n == 3:  # HWA, AHW, HWC, or CHW
            if tensor.size(2) == 3:
                return (["H", "W", "C"], "HWRGB")
            elif tensor.size(2) == 4:
                return (["H", "W", "C"], "HWRGBA")
            elif tensor.size(0) == 3:
                return (["C", "H", "W"], "RGBHW")
            elif tensor.size(0) == 4:
                return (["C", "H", "W"], "RGBAHW")
            elif tensor.size(2) == 1:
                return (["H", "W", "C"], "HWA")
            elif tensor.size(0) == 1:
                return (["C", "H", "W"], "AHW")
        elif dim_n == 4:  # BHWC or BCHW
            # Use a flat elif chain so a channels-last miss falls through to
            # the channels-first checks; channels-last is preferred when both
            # interpretations fit.
            if tensor.size(3) == 3:
                return (["B", "H", "W", "C"], "BHWRGB")
            elif tensor.size(3) == 4:
                return (["B", "H", "W", "C"], "BHWRGBA")
            elif tensor.size(1) == 3:
                return (["B", "C", "H", "W"], "BRGBHW")
            elif tensor.size(1) == 4:
                return (["B", "C", "H", "W"], "BRGBAHW")
        else:
            raise ValueError(
                f"{dim_n} dimensions is not a valid number of dimensions for an image tensor."
            )
        raise ValueError(
            f"Could not determine the layout of a tensor with {dim_n} dimensions and shape {tensor.shape}."
        )
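    # Example return values (illustrative; not from the original file):
    #   identify_type(torch.rand(3, 64, 64))    -> (["C", "H", "W"], "RGBHW")
    #   identify_type(torch.rand(1, 64, 64, 4)) -> (["B", "H", "W", "C"], "BHWRGBA")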
    @staticmethod
    def test_squeeze_batch(tensor: torch.Tensor, strict: bool = False) -> torch.Tensor:
        # Check whether the tensor has a batch dimension (i.e., 4 dimensions).
        if tensor.dim() == 4:
            if tensor.size(0) == 1 or not strict:
                # A batch of 1 represents a single image, so remove the batch
                # dimension. (When strict is False and the batch size is > 1,
                # squeeze(0) is a no-op and the batch is returned unchanged.)
                return tensor.squeeze(0)
            else:
                raise ValueError(
                    f"This is not a single image. It's a batch of {tensor.size(0)} images."
                )
        else:
            # No batch dimension, so return the tensor as-is.
            return tensor

    @staticmethod
    def test_unsqueeze_batch(tensor: torch.Tensor) -> torch.Tensor:
        # Check whether the tensor lacks a batch dimension (i.e., has only
        # 3 dimensions).
        if tensor.dim() == 3:
            # It represents a single image, so add a batch dimension of size 1.
            return tensor.unsqueeze(0)
        else:
            # It already has a batch dimension, so return the tensor as-is.
            return tensor
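    # Usage sketch for the batch helpers (illustrative; not from the original file):
    #   single = TensorImgUtils.test_squeeze_batch(torch.rand(1, 3, 8, 8))  # [3, 8, 8]
    #   batched = TensorImgUtils.test_unsqueeze_batch(single)               # [1, 3, 8, 8]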
    @staticmethod
    def most_pixels(img_tensors: list[torch.Tensor]) -> torch.Tensor:
        """Return the tensor with the largest spatial area (height * width)."""
        sizes = [
            TensorImgUtils.height_width(img)[0] * TensorImgUtils.height_width(img)[1]
            for img in img_tensors
        ]
        return img_tensors[sizes.index(max(sizes))]

    @staticmethod
    def height_width(image: torch.Tensor) -> Tuple[int, int]:
        """Like torchvision.transforms methods, this method assumes the Tensor
        to have [..., H, W] shape, where ... means an arbitrary number of
        leading dimensions.
        """
        return image.shape[-2], image.shape[-1]

    @staticmethod
    def smaller_axis(image: torch.Tensor) -> int:
        """Return the index of the smaller spatial axis, assuming a
        4-dimensional [B, C, H, W] layout (H is axis 2, W is axis 3)."""
        h, w = TensorImgUtils.height_width(image)
        return 2 if h < w else 3

    @staticmethod
    def larger_axis(image: torch.Tensor) -> int:
        """Return the index of the larger spatial axis, assuming a
        4-dimensional [B, C, H, W] layout (H is axis 2, W is axis 3)."""
        h, w = TensorImgUtils.height_width(image)
        return 2 if h > w else 3
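
# Minimal end-to-end demo (an illustrative assumption about intended usage;
# not part of the original file).
if __name__ == "__main__":
    img = torch.rand(1, 3, 64, 48)  # a BCHW batch holding one image
    print(TensorImgUtils.identify_type(img))  # (['B', 'C', 'H', 'W'], 'BRGBHW')
    print(TensorImgUtils.height_width(img))   # (64, 48)
    print(TensorImgUtils.smaller_axis(img))   # 3 (width is the smaller axis)
    hwc = TensorImgUtils.convert_to_type(img, "HWC")
    print(hwc.shape)                          # torch.Size([64, 48, 3])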