QZFantasies's picture
add wheels
c614b0f
raw
history blame
6.23 kB
# -*- coding: utf-8 -*-
# @Organization : Alibaba XR-Lab
# @Author : Lingteng Qiu
# @Email : [email protected]
# @Time : 2024-08-30 20:50:27
# @Function : Defines the base model/segmentation wrappers, the Bbox class and the Image container
import copy
import cv2
import numpy as np
import torch
class BaseModel(object):
    """Thin wrapper that forwards device/dtype/mode switches to ``self.model``.

    This class never assigns ``self.model`` itself; every method below reads
    it, so a subclass has to provide it before any of them is used.
    Each switching method returns ``self`` so calls can be chained.
    """

    def _forward_to_model(self, method_name, *args):
        """Call ``self.model.<method_name>(*args)`` and return this wrapper."""
        getattr(self.model, method_name)(*args)
        return self

    def cuda(self):
        """Invoke ``cuda()`` on the wrapped model."""
        return self._forward_to_model("cuda")

    def cpu(self):
        """Invoke ``cpu()`` on the wrapped model."""
        return self._forward_to_model("cpu")

    def float(self):
        """Invoke ``float()`` on the wrapped model."""
        return self._forward_to_model("float")

    def to(self, device):
        """Invoke ``to(device)`` on the wrapped model."""
        return self._forward_to_model("to", device)

    def eval(self):
        """Invoke ``eval()`` on the wrapped model."""
        return self._forward_to_model("eval")

    def train(self):
        """Invoke ``train()`` on the wrapped model."""
        return self._forward_to_model("train")

    def __call__(self, x):
        """Run the model on ``x``; must be overridden by subclasses."""
        raise NotImplementedError

    def __repr__(self):
        """Return the wrapped model's string form behind a ``model:`` header."""
        return "model: \n{}".format(self.model)
def get_dtype_string(arr):
    """Map ``arr.dtype`` to the type label used by this module.

    Returns ``"uint8"``, ``"float32"`` or ``"float"`` (for float64); any other
    dtype yields the fallback label ``"unknow"``.
    """
    labelled_dtypes = (
        (np.uint8, "uint8"),
        (np.float32, "float32"),
        (np.float64, "float"),
    )
    # Checked in the listed order; the first matching dtype wins.
    for candidate, label in labelled_dtypes:
        if arr.dtype == candidate:
            return label
    return "unknow"
class BaseSeg(BaseModel):
    """Base class for segmentation models built on ``BaseModel``."""

    def __init__(self):
        """Do nothing: no parent initialiser is called and no attribute is set."""
class Bbox:
    """Four-number box kept in one of two layouts.

    ``mode="whwh"`` stores ``[left, top, right, bottom]``;
    ``mode="xywh"`` stores ``[center_x, center_y, width, height]``.
    """

    def __init__(self, box, mode="whwh"):
        """Store ``box`` (exactly four values) together with its layout ``mode``."""
        assert len(box) == 4
        assert mode in ["whwh", "xywh"]
        self.mode = mode
        self.box = box

    @staticmethod
    def _corners(cx, cy, w, h):
        """Return ``(left, top, right, bottom)`` for a center/size description.

        The floored half extent goes to the left/top side and the remainder to
        the right/bottom side, so odd sizes are preserved exactly.
        """
        half_w = w // 2
        half_h = h // 2
        return cx - half_w, cy - half_h, cx + w - half_w, cy + h - half_h

    def to_xywh(self):
        """Return the box in center/size layout (``self`` if already there)."""
        if self.mode != "whwh":
            return self
        left, top, right, bottom = self.box
        center = [(left + right) / 2, (top + bottom) / 2]
        size = [right - left, bottom - top]
        return Bbox(center + size, mode="xywh")

    def to_whwh(self):
        """Return the box in corner layout (``self`` if already there)."""
        if self.mode == "whwh":
            return self
        center_x, center_y, width, height = self.box
        return Bbox(list(self._corners(center_x, center_y, width, height)), mode="whwh")

    def area(self):
        """Return width times height of the box."""
        width, height = self.to_xywh().box[2:]
        return width * height

    def get_box(self):
        """Return the four stored values truncated to ``int``."""
        return [int(value) for value in self.box]

    def scale(self, scale, width, height):
        """Grow or shrink the box around its center and clip it to the image.

        Args:
            scale: factor applied to the box width and height.
            width: upper bound for the right edge.
            height: upper bound for the bottom edge.

        Returns:
            A new corner-layout ``Bbox`` with integer coordinates.
        """
        # The center/size values are truncated to int before scaling.
        center_x, center_y, box_w, box_h = self.to_xywh().get_box()
        left, top, right, bottom = self._corners(
            center_x, center_y, box_w * scale, box_h * scale
        )
        clipped = [
            int(max(left, 0)),
            int(max(top, 0)),
            int(min(right, width)),
            int(min(bottom, height)),
        ]
        return Bbox(clipped, mode="whwh")

    def __repr__(self):
        """Describe the box by its corner coordinates."""
        l, t, r, b = self.to_whwh().box
        return f"BBox(left={l}, top={t}, right={r}, bottom={b})"
class Image:
    """Image container backed by a numpy array.

    The instance remembers the channel order (``"RGB"`` or ``"BGR"``) and a
    type label (one of ``TYPE_ORDER``) of the array held in ``self.data`` and
    can hand out copies in another order and/or type.

    Type labels: ``"uint8"`` is a uint8 array in [0, 255]; ``"float32"`` and
    ``"float"`` (float64) are floating-point arrays in [0, 1].
    """

    TYPE_ORDER = ["uint8", "float32", "float"]
    ORDER = ["RGB", "BGR"]
    MODE = ["numpy"]

    # numpy dtype behind each type label in TYPE_ORDER.
    _DTYPES = {"uint8": np.uint8, "float32": np.float32, "float": np.float64}

    def __init__(self, input, order="RGB", type_mode="uint8"):
        """Build an image from a file path, a numpy array or another ``Image``.

        Only 3-channel images are supported.

        Args:
            input: path string, ``np.ndarray`` or ``Image``.
            order: channel order the stored data should have.
            type_mode: type label the stored data should have.
        """
        if isinstance(input, str):
            self.data = self.read_image(input, type_mode, order)
        else:
            self.data = self.get_image(input, type_mode, order)
        self.order = order
        self.type_mode = type_mode

    def get_image(self, input, type_mode, order):
        """Convert an in-memory image to the requested type and channel order.

        A numpy array is assumed to be in RGB order and its type label is
        derived from its dtype.

        Raises:
            NotImplementedError: if ``input`` is neither ``Image`` nor ndarray.
        """
        if isinstance(input, Image):
            return input.to_numpy(type_mode, order)
        elif isinstance(input, np.ndarray):
            # Describe the raw array on this instance so to_numpy() can
            # convert it; __init__ overwrites order/type_mode afterwards.
            self.data = input
            self.order = "RGB"  # default
            self.type_mode = get_dtype_string(input)
            return self.to_numpy(type_mode, order)
        else:
            raise NotImplementedError

    @staticmethod
    def _convert_type(data, type_mode):
        """Return ``data`` converted to the dtype/range of ``type_mode``."""
        if type_mode not in Image._DTYPES:
            raise ValueError(f"Unknown type_mode {type_mode}")
        target = Image._DTYPES[type_mode]
        if data.dtype == target:
            return data
        if data.dtype == np.uint8:
            # uint8 [0, 255] -> floating point [0, 1]
            return (data / 255.0).astype(target)
        if not np.issubdtype(data.dtype, np.floating):
            # The value range of other integer/bool arrays is unknown.
            raise ValueError(f"Unsupported image dtype {data.dtype}")
        if target == np.uint8:
            # floating point [0, 1] -> uint8 [0, 255]; clip so out-of-range
            # values saturate instead of wrapping around.
            return np.clip(data * 255.0, 0, 255).astype(np.uint8)
        return data.astype(target)

    def to_numpy(self, type_mode="uint8", order="RGB"):
        """Return a copy of the data in the requested type and channel order.

        Args:
            type_mode: one of ``TYPE_ORDER``.
            order: requested channel order; when it differs from the stored
                order the last axis is reversed (RGB <-> BGR only).

        Returns:
            A new array; ``self.data`` is never modified.

        Raises:
            ValueError: if a conversion is needed and ``type_mode`` is unknown
                or the stored dtype is neither uint8 nor floating point.
        """
        data = copy.deepcopy(self.data)
        if type_mode != self.type_mode:
            data = self._convert_type(data, type_mode)
        if order != self.order:
            # Reversing the last axis yields a negative-stride view, which
            # torch.from_numpy() rejects; materialise it contiguously.
            data = np.ascontiguousarray(data[..., ::-1])
        return data

    def to_tensor(self, order):
        """Return the image as a float32 torch tensor in channel ``order``."""
        data = self.to_numpy(type_mode="float32", order=order)
        return torch.from_numpy(data)

    def read_image(
        self,
        path,
        mode,
        order,
    ):
        """Read an image file into a numpy array.

        Args:
            path (str): path to the image file.
            mode (str): returned image format.
                uint8: uint8 numpy array, range [0, 255];
                float32: float32 numpy array, range [0, 1];
                float: same result as float32 when the decoded image is uint8.
                    NOTE(review): elsewhere in this class "float" stands for
                    float64; kept as-is here so existing callers see the
                    same dtype -- confirm which one is intended.
            order (str): channel order, "RGB", "RGBA", "BGR" or "BGRA".

        Note:
            An image with an alpha channel is blended onto a white background
            unless ``order`` contains "A".

        Returns:
            np.ndarray: the image array.

        Raises:
            NotImplementedError: for ``mode="pil"`` (PIL is not available to
                this module, so that output was never functional).
            ValueError: for any other unknown ``mode``.
            OSError: if the file cannot be read or decoded.
        """
        if mode == "pil":
            raise NotImplementedError(
                "read_image mode 'pil' is not supported; use one of "
                f"{Image.TYPE_ORDER}"
            )
        if mode not in Image._DTYPES:
            raise ValueError(f"Unknown read_image mode {mode}")

        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if img is None:
            # cv2.imread signals a missing/undecodable file by returning None.
            raise OSError(f"Failed to read image file: {path}")

        # cv2 decodes to BGR(A); swap channels when RGB(A) is requested.
        if len(img.shape) == 3:  # ignore if gray scale
            if order in ["RGB", "RGBA"]:
                if img.shape[-1] == 4:
                    img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGBA)
                elif img.shape[-1] == 3:
                    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            # Blend onto a white background when alpha is not requested.
            if img.shape[-1] == 4 and "A" not in order:
                img = img.astype(np.float32) / 255
                img = img[..., :3] * img[..., 3:] + (1 - img[..., 3:])

        if mode == "uint8":
            if img.dtype != np.uint8:
                img = (img * 255).astype(np.uint8)
        else:  # "float32" or "float"
            if img.dtype == np.uint8:
                img = img.astype(np.float32) / 255
        return img