Spaces:
Running
on
Zero
Running
on
Zero
# -*- coding: utf-8 -*- | |
# @Organization : Alibaba XR-Lab | |
# @Author : Lingteng Qiu | |
# @Email : [email protected] | |
# @Time : 2024-08-30 20:50:27 | |
# @Function : The class defines bbox, base-seg module | |
import copy | |
import cv2 | |
import numpy as np | |
import torch | |
class BaseModel(object):
    """Minimal wrapper that forwards device/dtype/mode calls to ``self.model``.

    This class never assigns ``self.model`` itself; whoever subclasses or
    instantiates it has to set that attribute before using any method below.
    Every forwarding method returns the wrapper (not the wrapped object), so
    calls can be chained.
    """

    def _delegate(self, method_name, *args):
        """Invoke ``self.model.<method_name>(*args)`` and return this wrapper."""
        getattr(self.model, method_name)(*args)
        return self

    def cuda(self):
        """Forward ``cuda()`` to the wrapped model."""
        return self._delegate("cuda")

    def cpu(self):
        """Forward ``cpu()`` to the wrapped model."""
        return self._delegate("cpu")

    def float(self):
        """Forward ``float()`` to the wrapped model."""
        return self._delegate("float")

    def to(self, device):
        """Forward ``to(device)`` to the wrapped model."""
        return self._delegate("to", device)

    def eval(self):
        """Forward ``eval()`` to the wrapped model."""
        return self._delegate("eval")

    def train(self):
        """Forward ``train()`` (no arguments) to the wrapped model."""
        return self._delegate("train")

    def __call__(self, x):
        """Inference entry point; concrete subclasses must override it."""
        raise NotImplementedError

    def __repr__(self):
        """Show the wrapped model's own string form behind a ``model:`` header."""
        return f"model: \n{self.model}"
def get_dtype_string(arr):
    """Return the type-mode label that matches ``arr.dtype``.

    Args:
        arr: object exposing a ``dtype`` attribute comparable to numpy scalar
            types (e.g. a ``numpy.ndarray``).

    Returns:
        ``"uint8"``, ``"float32"`` or ``"float"`` (the label used for float64);
        ``"unknow"`` for every other dtype.
    """
    labelled_dtypes = (
        (np.uint8, "uint8"),
        (np.float32, "float32"),
        (np.float64, "float"),
    )
    for candidate, label in labelled_dtypes:
        if arr.dtype == candidate:
            return label
    return "unknow"
class BaseSeg(BaseModel):
    """Subclass of ``BaseModel`` that adds no behaviour of its own.

    NOTE(review): presumably the common base for segmentation wrappers --
    confirm against the subclasses that derive from it.
    """

    def __init__(self):
        """Do nothing; in particular ``self.model`` is not assigned here."""
        return None
class Bbox:
    """Axis-aligned bounding box stored in one of two layouts.

    * ``mode="whwh"``: ``[left, top, right, bottom]`` corner coordinates.
    * ``mode="xywh"``: ``[center_x, center_y, width, height]``.

    Conversion methods return a new ``Bbox`` (or ``self`` when the box is
    already in the requested layout); the stored ``box`` is never modified.
    """

    def __init__(self, box, mode="whwh"):
        """Store ``box`` (a 4-element sequence) together with its layout ``mode``.

        Raises:
            AssertionError: if ``box`` does not have four entries or ``mode``
                is neither ``"whwh"`` nor ``"xywh"``.
        """
        assert len(box) == 4, f"box must have 4 entries, got {len(box)}"
        assert mode in ["whwh", "xywh"], f"unsupported bbox mode: {mode!r}"
        self.box = box
        self.mode = mode

    def to_xywh(self):
        """Return this box in center/size layout (``self`` if already ``xywh``)."""
        if self.mode == "whwh":
            l, t, r, b = self.box
            center_x = (l + r) / 2
            center_y = (t + b) / 2
            width = r - l
            height = b - t
            return Bbox([center_x, center_y, width, height], mode="xywh")
        else:
            return self

    def to_whwh(self):
        """Return this box in corner layout (``self`` if already ``whwh``).

        Uses exact halves of the width/height so that it is the inverse of
        :meth:`to_xywh`, which produces fractional centers for odd sizes.
        """
        if self.mode == "whwh":
            return self
        else:
            cx, cy, w, h = self.box
            # True division, not floor division: flooring the half-size here
            # shifted boxes with odd sizes by half a pixel on a round trip.
            half_w = w / 2
            half_h = h / 2
            return Bbox(
                [cx - half_w, cy - half_h, cx + half_w, cy + half_h],
                mode="whwh",
            )

    def area(self):
        """Return ``width * height`` of the box."""
        box = self.to_xywh()
        _, __, w, h = box.box
        return w * h

    def get_box(self):
        """Return the four stored values converted with ``int()`` (truncation)."""
        return list(map(int, self.box))

    def scale(self, scale, width, height):
        """Grow or shrink the box about its center and clamp it to an image.

        Args:
            scale: multiplier applied to both width and height.
            width: right-hand clamp for the result (left is clamped at 0).
            height: bottom clamp for the result (top is clamped at 0).

        Returns:
            A new integer-valued ``Bbox`` in ``"whwh"`` layout.
        """
        new_box = self.to_xywh()
        # Center and size are truncated to ints before scaling.
        cx, cy, w, h = new_box.get_box()
        w = w * scale
        h = h * scale
        # Floor the left/top half and give the remainder to the right/bottom,
        # so the scaled extent is preserved around the integer center.
        l = cx - w // 2
        t = cy - h // 2
        r = cx + w - (w // 2)
        b = cy + h - (h // 2)
        l = int(max(l, 0))
        t = int(max(t, 0))
        r = int(min(r, width))
        b = int(min(b, height))
        return Bbox([l, t, r, b], mode="whwh")

    def __repr__(self):
        """Render the box as its left/top/right/bottom corners."""
        box = self.to_whwh()
        l, t, r, b = box.box
        return f"BBox(left={l}, top={t}, right={r}, bottom={b})"
class Image:
    """Image held as a numpy array together with its channel order and value type.

    Value-type labels (``TYPE_ORDER``): ``"uint8"`` holds values in [0, 255];
    ``"float32"`` and ``"float"`` (float64) hold values in [0, 1]. Channel
    order is ``"RGB"`` or ``"BGR"``; only 3-channel images are supported by
    the order conversion.
    """

    TYPE_ORDER = ["uint8", "float32", "float"]
    ORDER = ["RGB", "BGR"]
    MODE = ["numpy"]
    # numpy dtype behind each TYPE_ORDER label ("float" means float64,
    # matching get_dtype_string).
    _NUMPY_DTYPES = {"uint8": np.uint8, "float32": np.float32, "float": np.float64}

    def __init__(self, input, order="RGB", type_mode="uint8"):
        """Build an image from a file path, a numpy array or another ``Image``.

        Args:
            input: ``str`` path (read with OpenCV), ``np.ndarray`` (assumed to
                be in RGB order) or ``Image``.
            order: channel order the stored data should have.
            type_mode: value-type label the stored data should have.

        Only 3-channel images are supported.
        """
        if isinstance(input, str):
            self.data = self.read_image(input, type_mode, order)
        else:
            self.data = self.get_image(input, type_mode, order)
        self.order = order
        self.type_mode = type_mode

    def get_image(self, input, type_mode, order):
        """Return ``input`` as a numpy array in the requested type and order.

        Raises:
            NotImplementedError: if ``input`` is neither ``Image`` nor ndarray.
        """
        if isinstance(input, Image):
            return input.to_numpy(type_mode, order)
        elif isinstance(input, np.ndarray):
            # Temporarily describe the raw array on self so to_numpy() can
            # convert it; __init__ overwrites order/type_mode afterwards.
            self.data = input
            self.order = "RGB"  # default
            self.type_mode = get_dtype_string(input)
            return self.to_numpy(type_mode, order)
        else:
            raise NotImplementedError

    def to_numpy(self, type_mode="uint8", order="RGB"):
        """Return a copy of the pixel data in the requested type and order.

        Args:
            type_mode: target value-type label (one of ``TYPE_ORDER``).
            order: target channel order; any value different from
                ``self.order`` reverses the last axis (RGB <-> BGR only).

        Returns:
            A new ``np.ndarray``; the stored data is left untouched.

        Raises:
            ValueError: if a conversion is needed and ``type_mode`` is not a
                known label.
        """
        data = copy.deepcopy(self.data)
        if order != self.order:
            # Materialise the flip: a negative-stride view cannot be handed to
            # torch.from_numpy() in to_tensor().
            data = np.ascontiguousarray(data[..., ::-1])
        if type_mode == self.type_mode:
            return data
        return self._convert_dtype(data, type_mode)

    @classmethod
    def _convert_dtype(cls, data, type_mode):
        """Convert ``data`` to the dtype/value range named by ``type_mode``."""
        try:
            target = cls._NUMPY_DTYPES[type_mode]
        except KeyError:
            raise ValueError(
                f"Unsupported type_mode {type_mode!r}; expected one of {cls.TYPE_ORDER}"
            ) from None

        # Decide the source value range from the actual dtype rather than from
        # a label, so arrays with an unrecognised label still convert sensibly.
        kind = data.dtype.kind
        if target is np.uint8:
            if kind == "f":
                # [0, 1] floats -> [0, 255]; round (not truncate) so that
                # uint8 -> float -> uint8 gives back the original values.
                return np.clip(np.rint(data * 255.0), 0, 255).astype(np.uint8)
            return data.astype(np.uint8)
        if kind in ("u", "i"):
            # Integer data is treated as [0, 255] and mapped to [0, 1].
            return (data / 255.0).astype(target)
        return data.astype(target)

    def to_tensor(self, order):
        """Return the image as a float32 ``torch.Tensor`` in channel ``order``."""
        data = self.to_numpy(type_mode="float32", order=order)
        return torch.from_numpy(data)

    def read_image(
        self,
        path,
        mode,
        order,
    ):
        """Read an image file with OpenCV into a numpy array.

        Args:
            path (str): path to the image file.
            mode (str): returned value type.
                uint8: uint8 numpy array, range [0, 255];
                float / float32: float32 numpy array, range [0, 1].
            order (str): channel order. ``"RGB"``/``"RGBA"`` convert from
                OpenCV's BGR(A); any other value keeps the order as read.

        Note:
            A 4-channel image is blended onto a white background and reduced
            to 3 channels unless ``order`` contains ``"A"``.
            NOTE(review): ``mode="float"`` yields float32 here, although the
            ``"float"`` label denotes float64 elsewhere in this class.

        Returns:
            np.ndarray: the image array.

        Raises:
            NotImplementedError: for ``mode="pil"`` (PIL is not available in
                this module).
            ValueError: for any other unknown ``mode``.
            OSError: if OpenCV cannot read ``path``.
        """
        if mode == "pil":
            # The name ``Image`` in this module is this class, not PIL.Image,
            # so a PIL result cannot be produced here.
            raise NotImplementedError(
                "read_image mode 'pil' is not supported: PIL is not imported in this module"
            )
        if mode not in ("uint8", "float", "float32"):
            raise ValueError(f"Unknown read_image mode {mode}")

        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if img is None:
            # cv2.imread signals a missing/undecodable file by returning None.
            raise OSError(f"Failed to read image file: {path}")

        # cvtColor
        if len(img.shape) == 3:  # ignore if gray scale
            if order in ["RGB", "RGBA"]:
                if img.shape[-1] == 4:
                    img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGBA)
                elif img.shape[-1] == 3:
                    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            # mix background: alpha-blend onto white (1.0) and drop alpha
            if img.shape[-1] == 4 and "A" not in order:
                img = img.astype(np.float32) / 255
                img = img[..., :3] * img[..., 3:] + (1 - img[..., 3:])

        # mode
        if mode == "uint8":
            if img.dtype != np.uint8:
                img = (img * 255).astype(np.uint8)
        elif img.dtype == np.uint8:
            img = img.astype(np.float32) / 255
        return img