from math import ceil
from pathlib import Path

from lightning import LightningModule
from torchmetrics.detection import MeanAveragePrecision

from yolo.config.config import Config
from yolo.model.yolo import create_model
from yolo.tools.data_loader import create_dataloader
from yolo.tools.drawer import draw_bboxes
from yolo.tools.loss_functions import create_loss_function
from yolo.utils.bounding_box_utils import create_converter, to_metrics_format
from yolo.utils.model_utils import PostProcess, create_optimizer, create_scheduler

class BaseModel(LightningModule):
    """Minimal Lightning wrapper around the YOLO network built from the config."""

    def __init__(self, cfg: Config):
        super().__init__()
        self.model = create_model(cfg.model, class_num=cfg.dataset.class_num, weight_path=cfg.weight)

    def forward(self, x):
        return self.model(x)

class ValidateModel(BaseModel):
    """Runs the model over the validation set and accumulates COCO-style mAP."""

    def __init__(self, cfg: Config):
        super().__init__(cfg)
        self.cfg = cfg
        # In a standalone validation run the task config *is* the validation
        # config; during training it is nested under task.validation.
        if self.cfg.task.task == "validation":
            self.validation_cfg = self.cfg.task
        else:
            self.validation_cfg = self.cfg.task.validation
        self.metric = MeanAveragePrecision(iou_type="bbox", box_format="xyxy", backend="faster_coco_eval")
        self.metric.warn_on_many_detections = False  # silence the >100-detections-per-image warning
        self.val_loader = create_dataloader(self.validation_cfg.data, self.cfg.dataset, self.validation_cfg.task)
        self.ema = self.model  # weights used for evaluation; may be replaced by an EMA copy

    def setup(self, stage):
        self.vec2box = create_converter(
            self.cfg.model.name, self.model, self.cfg.model.anchor, self.cfg.image_size, self.device
        )
        self.post_process = PostProcess(self.vec2box, self.validation_cfg.nms)

    def val_dataloader(self):
        return self.val_loader

    def validation_step(self, batch, batch_idx):
        batch_size, images, targets, rev_tensor, img_paths = batch
        H, W = images.shape[2:]
        predicts = self.post_process(self.ema(images), image_size=[W, H])
        self.metric.update(
            [to_metrics_format(predict) for predict in predicts], [to_metrics_format(target) for target in targets]
        )
        return predicts

    def on_validation_epoch_end(self):
        epoch_metrics = self.metric.compute()
        del epoch_metrics["classes"]  # per-class tensor cannot be logged as a scalar
        self.log_dict(epoch_metrics, prog_bar=True, sync_dist=True, rank_zero_only=True)
        self.log_dict(
            {"PyCOCO/AP @ .5:.95": epoch_metrics["map"], "PyCOCO/AP @ .5": epoch_metrics["map_50"]},
            sync_dist=True,
            rank_zero_only=True,
        )
        self.metric.reset()

class TrainModel(ValidateModel):
    """Adds the training loop: data loading, dual-head loss, and per-batch LR stepping."""

    def __init__(self, cfg: Config):
        super().__init__(cfg)
        self.cfg = cfg
        self.train_loader = create_dataloader(self.cfg.task.data, self.cfg.dataset, self.cfg.task.task)

    def setup(self, stage):
        super().setup(stage)
        self.loss_fn = create_loss_function(self.cfg, self.vec2box)

    def train_dataloader(self):
        return self.train_loader

    def on_train_epoch_start(self):
        # Tell the optimizer wrapper how many steps this process runs per epoch
        # so it can schedule learning rate and momentum batch by batch.
        self.trainer.optimizers[0].next_epoch(
            ceil(len(self.train_loader) / self.trainer.world_size), self.current_epoch
        )
        self.vec2box.update(self.cfg.image_size)

    def training_step(self, batch, batch_idx):
        lr_dict = self.trainer.optimizers[0].next_batch()
        batch_size, images, targets, *_ = batch
        predicts = self(images)
        # Decode the auxiliary and main heads separately; the loss uses both.
        aux_predicts = self.vec2box(predicts["AUX"])
        main_predicts = self.vec2box(predicts["Main"])
        loss, loss_item = self.loss_fn(aux_predicts, main_predicts, targets)
        self.log_dict(
            loss_item,
            prog_bar=True,
            on_epoch=True,
            batch_size=batch_size,
            rank_zero_only=True,
        )
        self.log_dict(lr_dict, prog_bar=False, logger=True, on_epoch=False, rank_zero_only=True)
        return loss * batch_size

    def configure_optimizers(self):
        optimizer = create_optimizer(self.model, self.cfg.task.optimizer)
        scheduler = create_scheduler(optimizer, self.cfg.task.scheduler)
        return [optimizer], [scheduler]

class InferenceModel(BaseModel):
    """Runs inference, draws predicted boxes, and optionally saves or streams the result."""

    def __init__(self, cfg: Config):
        super().__init__(cfg)
        self.cfg = cfg
        # TODO: Add FastModel
        self.predict_loader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task)

    def setup(self, stage):
        self.vec2box = create_converter(
            self.cfg.model.name, self.model, self.cfg.model.anchor, self.cfg.image_size, self.device
        )
        self.post_process = PostProcess(self.vec2box, self.cfg.task.nms)

    def predict_dataloader(self):
        return self.predict_loader

    def predict_step(self, batch, batch_idx):
        images, rev_tensor, origin_frame = batch
        predicts = self.post_process(self(images), rev_tensor=rev_tensor)
        img = draw_bboxes(origin_frame, predicts, idx2label=self.cfg.dataset.class_list)
        if getattr(self.predict_loader, "is_stream", None):
            fps = self._display_stream(img)
        else:
            fps = None
        if getattr(self.cfg.task, "save_predict", None):
            self._save_image(img, batch_idx)
        return img, fps

    def _save_image(self, img, batch_idx):
        save_image_path = Path(self.trainer.default_root_dir) / f"frame{batch_idx:03d}.png"
        img.save(save_image_path)
        print(f"💾 Saved visualized image at {save_image_path}")