# import some common libraries
import numpy as np

# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.utils.visualizer import ColorMode
import detectron2.data.transforms as T
from predictor import InferenceBase
import torch
import torchvision.transforms as transforms
from PIL import Image
from detectron2.data.detection_utils import pil_image_handler
# Constants defining the supported model categories.
class ModelCategory:
    IMAGE_FEATURE_EXTRACT = "image_feature_extract"
    IMAGE_CLASSIFICATION = "image_classification"
    OBJECT_DETECTION = "object_detection"
    ONE_STEP_OBJECT_DETECTION = "onestep_object_detection"
    SEMANTIC_SEGMENTATION = "semantic_segmentation"
    INSTANCE_SEGMENTATION = "instance_segmentation"
    PANOPTIC_SEGMENTATION = "panoptic_segmentation"
    KEYPOINTS = "keypoints"
    REGRESSION = "regression"
    TEXT_CLASSIFICATION = "text_classification"
    LANGUAGE_MODELLING = "language_modelling"
    TRANSLATION = "translation"
    QA_SYSTEM = "qa_system"
    RECOMMENDATION_SYSTEM = "recommendation_system"
    GENERATIVE_MODELLING = "generative_modelling"
    CONTROL = "control"
    ROBOTICS = "robotics"
    YOLO = "yolo"
    OTHERS = "others"
class ModelConfig:
    """Builds a detectron2 config for a given model category."""

    def __init__(self, model_type, model_path: str = None, cfg_path: str = None, threshold: float = 0.5):
        self.cfg = get_cfg()
        if cfg_path is not None:
            self.cfg.merge_from_file(cfg_path)
        if model_path is not None:
            self.cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(model_path)
        self.threshold = threshold
        if model_type == ModelCategory.IMAGE_FEATURE_EXTRACT:
            self.cfg.TASK_TYPE = "feature"
            self.cfg.MODEL.WEIGHTS = None
        elif model_type == ModelCategory.IMAGE_CLASSIFICATION:
            # NOTE: "classfication" (sic) matches the key checked in ModelFactory.serialize().
            self.cfg.TASK_TYPE = "classfication"
            self.cfg.MODEL.WEIGHTS = None
        elif model_type == ModelCategory.SEMANTIC_SEGMENTATION:
            self.cfg.TASK_TYPE = "semantic"
            self.cfg.MODEL.WEIGHTS = None
        elif model_type == ModelCategory.YOLO:
            self.cfg.TASK_TYPE = "yolo"
            self.cfg.MODEL.WEIGHTS = None

    def get_cfg(self):
        return self.cfg
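

# The threshold accepted by ModelConfig is stored but never written into the detectron2
# config. The helper below is a hedged sketch (not part of the original module) showing
# one way it could be applied, assuming the standard detectron2 config keys
# MODEL.ROI_HEADS.SCORE_THRESH_TEST and MODEL.RETINANET.SCORE_THRESH_TEST are present.
def apply_score_threshold(cfg, threshold: float = 0.5):
    # Two-stage (Faster/Mask R-CNN style) heads read ROI_HEADS.SCORE_THRESH_TEST.
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = threshold
    # Single-stage RetinaNet models read RETINANET.SCORE_THRESH_TEST.
    cfg.MODEL.RETINANET.SCORE_THRESH_TEST = threshold
    return cfg
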
class ModelFactory:
    _instances = {}

    def __init__(self):
        self.need_save_images = False

    def get_instance(self, category, cfg):
        # Cache one InferenceBase per model category so repeated calls reuse the loaded model.
        if category not in self._instances:
            self._instances[category] = InferenceBase(cfg)
        return self._instances[category]
    def serialize(self, output):
        serialized = None
        # print(output)
        if "instances" in output:
            serialized = {
                'image_height': output["instances"].image_size[0],
                'image_width': output["instances"].image_size[1],
                'pred_boxes': output["instances"].pred_boxes.tolist() if isinstance(output["instances"].pred_boxes, torch.Tensor) else output["instances"].pred_boxes.tensor.tolist(),
                'scores': output["instances"].scores.tolist() if output["instances"].has("scores") else None,
                'pred_classes': output["instances"].pred_classes.tolist() if output["instances"].has("pred_classes") else None
            }
            if hasattr(output["instances"], "pred_masks"):
                # serialized["pred_masks"] = output["instances"].pred_masks.tolist()
                print("instances.pred_masks", output["instances"].pred_masks.shape)
            if hasattr(output["instances"], "pred_keypoints"):
                serialized["pred_keypoints"] = output["instances"].pred_keypoints.tolist()
        if "sem_seg" in output:
            # serialized["sem_seg"] = output["sem_seg"].tolist()
            print("sem_seg:", output["sem_seg"].shape)
        if "panoptic_seg" in output:
            print("panoptic_seg:", output["panoptic_seg"][0].shape)
            # print("panoptic_seg:", output["panoptic_seg"])
            serialized["panoptic_seg"] = output["panoptic_seg"][1]
        if "sem_segs" in output:
            print("sem_segs:", output["sem_segs"].shape)
        if "classfication" in output:
            serialized = []
            for item in output["classfication"]:
                print("classfication:", item["feature"].shape)
                row = {
                    # "feature": item["feature"].tolist(),
                    "score": item["score"].tolist(),
                    "pred_class": item["pred_class"].tolist(),
                }
                serialized.append(row)
        if "features" in output:
            print("features:", output["features"].shape)
            serialized = {
                "features": output["features"].tolist(),
            }
        if serialized is None:
            return output
        return serialized
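
    # Hypothetical helper (a sketch, not part of the original factory): pred_masks on a
    # detectron2 Instances object is an (N, H, W) boolean tensor, so it can be made
    # JSON-serializable by casting to uint8 before tolist(). Beware the payload size for
    # large images; serialize() above deliberately only prints the shape.
    def serialize_masks(self, instances):
        if not instances.has("pred_masks"):
            return None
        # Convert the boolean mask tensor to nested Python lists of 0/1 values.
        return instances.pred_masks.to(torch.uint8).cpu().numpy().tolist()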
    def predict(self, pil_image, task_type="panoptic"):
        result = None
        vis_output = None
        if task_type == "panoptic":
            result, vis_output = self.panoptic_segment(input_image=pil_image)
        elif task_type == "detect":
            result, vis_output = self.detect(input_image=pil_image)
        elif task_type == "classification":
            result = self.classify(input_image=pil_image)
        elif task_type == "instance":
            result, vis_output = self.instance_segment(input_image=pil_image)
        elif task_type == "semantic":
            result, vis_output = self.semantic_segment(input_image=pil_image)
        elif task_type == "feature":
            result = self.extract(input_image=pil_image)
        elif task_type == "keypoint":
            result, vis_output = self.keypoint(input_image=pil_image)
        elif task_type == "onestep_detect":
            result, vis_output = self.onestep_detect(input_image=pil_image)
        elif task_type == "yolo":
            result, vis_output = self.yolo(input_image=pil_image)
        return self.serialize(result), vis_output
    def extract(self, input_image=None, image_path: str = "./test.png"):
        """
        Extract image features using Detectron2.
        """
        cfg = ModelConfig(ModelCategory.IMAGE_FEATURE_EXTRACT,
                          model_path=None,
                          cfg_path=None).get_cfg()
        p = self.get_instance(ModelCategory.IMAGE_FEATURE_EXTRACT, cfg)
        if input_image is None and image_path is not None:
            input_image = Image.open(image_path).convert('RGB')
        input_image = pil_image_handler(input_image)
        outputs, _ = p.run_on_image(input_image)
        return outputs
    def classify(self, input_image=None, image_path: str = "./cat.jpg"):
        """
        Perform classification on an image using Detectron2.
        """
        cfg = ModelConfig(ModelCategory.IMAGE_CLASSIFICATION,
                          model_path=None,
                          cfg_path=None).get_cfg()
        p = self.get_instance(ModelCategory.IMAGE_CLASSIFICATION, cfg)
        if input_image is None and image_path is not None:
            input_image = Image.open(image_path).convert('RGB')
        input_image = pil_image_handler(input_image)
        outputs, _ = p.run_on_image(input_image)
        return outputs
    def onestep_detect(self, input_image=None, image_path: str = "./test.png", confidence_threshold: float = 0.5):
        """
        Perform one-step (single-stage) object detection on an image using Detectron2.
        """
        cfg = ModelConfig(ModelCategory.ONE_STEP_OBJECT_DETECTION,
                          model_path="COCO-Detection/retinanet_R_101_FPN_3x.yaml",
                          cfg_path="../configs/COCO-Detection/retinanet_R_101_FPN_3x.yaml").get_cfg()
        p = self.get_instance(ModelCategory.ONE_STEP_OBJECT_DETECTION, cfg)
        if input_image is None and image_path is not None:
            input_image = p.read_image(image_path)
        else:
            input_image = pil_image_handler(input_image)
        outputs, vis_output = p.run_on_image(input_image)
        return outputs, vis_output
    def detect(self, input_image=None, image_path: str = "./test.png", confidence_threshold: float = 0.5):
        """
        Perform object detection on an image using Detectron2.
        """
        cfg = ModelConfig(ModelCategory.OBJECT_DETECTION,
                          model_path="COCO-Detection/faster_rcnn_R_101_FPN_3x.yaml",
                          cfg_path="../configs/COCO-Detection/faster_rcnn_R_101_FPN_3x.yaml").get_cfg()
        p = self.get_instance(ModelCategory.OBJECT_DETECTION, cfg)
        if input_image is None and image_path is not None:
            input_image = p.read_image(image_path)
        else:
            input_image = pil_image_handler(input_image)
        outputs, vis_output = p.run_on_image(input_image)
        return outputs, vis_output
    def instance_segment(self, input_image=None, image_path: str = "./test.png"):
        """
        Perform instance segmentation on an image using Detectron2.
        """
        cfg = ModelConfig(ModelCategory.INSTANCE_SEGMENTATION,
                          model_path="COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml",
                          cfg_path="../configs/COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml").get_cfg()
        p = self.get_instance(ModelCategory.INSTANCE_SEGMENTATION, cfg)
        if input_image is None and image_path is not None:
            input_image = p.read_image(image_path)
        else:
            input_image = pil_image_handler(input_image)
        outputs, vis_output = p.run_on_image(input_image)
        return outputs, vis_output
    def semantic_segment(self, input_image=None, image_path: str = "./test.png"):
        """
        Perform semantic segmentation on an image using Detectron2.
        """
        cfg = ModelConfig(ModelCategory.SEMANTIC_SEGMENTATION,
                          model_path=None,
                          cfg_path="../configs/PascalVOC-Detection/faster_rcnn_R_50_FPN.yaml").get_cfg()
        p = self.get_instance(ModelCategory.SEMANTIC_SEGMENTATION, cfg)
        if input_image is None and image_path is not None:
            input_image = Image.open(image_path).convert('RGB')
        input_image = pil_image_handler(input_image)
        outputs, vis_output = p.run_on_image(input_image)
        return outputs, vis_output
    def panoptic_segment(self, input_image=None, image_path: str = "./test.png"):
        """
        Perform panoptic segmentation on an image using Detectron2.
        """
        cfg = ModelConfig(ModelCategory.PANOPTIC_SEGMENTATION,
                          model_path="COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml",
                          cfg_path="../configs/COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml").get_cfg()
        p = self.get_instance(ModelCategory.PANOPTIC_SEGMENTATION, cfg)
        if input_image is None and image_path is not None:
            input_image = p.read_image(image_path)
        else:
            input_image = pil_image_handler(input_image)
        outputs, vis_output = p.run_on_image(input_image)
        # outputs['sem_seg'] = outputs['sem_seg'].numpy().tolist()
        return outputs, vis_output
    def keypoint(self, input_image=None, image_path: str = "./test.png"):
        """
        Perform keypoint detection on an image using Detectron2.
        """
        cfg = ModelConfig(ModelCategory.KEYPOINTS,
                          model_path="COCO-Keypoints/keypoint_rcnn_R_101_FPN_3x.yaml",
                          cfg_path="../configs/COCO-Keypoints/keypoint_rcnn_R_101_FPN_3x.yaml").get_cfg()
        p = self.get_instance(ModelCategory.KEYPOINTS, cfg)
        if input_image is None and image_path is not None:
            input_image = p.read_image(image_path)
        else:
            input_image = pil_image_handler(input_image)
        outputs, vis_output = p.run_on_image(input_image)
        return outputs, vis_output
    def yolo(self, input_image=None, image_path: str = "./test/test.png"):
        """
        Run YOLO-style detection on an image.
        """
        cfg = ModelConfig(ModelCategory.YOLO,
                          model_path=None,
                          cfg_path=None).get_cfg()
        p = self.get_instance(ModelCategory.YOLO, cfg)
        if input_image is None and image_path is not None:
            input_image = Image.open(image_path).convert('RGB')
        input_image = pil_image_handler(input_image)
        outputs, vis_output = p.run_on_image(input_image)
        return outputs, vis_output
# if __name__ == "__main__":
#     f = ModelFactory()
#     # f.prepare_meta()
#     out = f.yolo()
#     print(out)
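
# A minimal usage sketch, assuming a local "./test.png" exists and the ../configs paths
# referenced above resolve on this machine.
if __name__ == "__main__":
    factory = ModelFactory()
    image = Image.open("./test.png").convert("RGB")
    # Run panoptic segmentation and print the JSON-serializable result.
    result, vis = factory.predict(image, task_type="panoptic")
    print(result)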