yutyan's picture
Add app
37f5c2f
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# Copyright (c) Megvii Inc. All rights reserved.
import torch.nn as nn
from .yolo_head import YOLOXHead
from .yolo_pafpn import YOLOPAFPN
class YOLOX(nn.Module):
    """
    YOLOX detector: a feature-pyramid backbone followed by a detection head.

    The backbone is run on the input batch and its outputs are handed to the
    head. In training mode ``targets`` are required and the head's loss tuple
    is returned as a dict; in eval mode the head's output on the backbone
    features alone (the detection results) is returned unchanged.

    Args:
        backbone: module mapping the input batch to FPN features.
            Defaults to ``YOLOPAFPN()``.
        head: module turning FPN features into losses (training) or
            predictions (eval). Defaults to ``YOLOXHead(80)``; the ``80`` is
            presumably the number of classes (COCO) -- confirm against
            ``YOLOXHead``.
    """

    def __init__(self, backbone=None, head=None):
        super().__init__()
        # Defaults are built here (not as default arguments) so each model
        # gets its own freshly constructed submodules.
        if backbone is None:
            backbone = YOLOPAFPN()
        if head is None:
            head = YOLOXHead(80)
        self.backbone = backbone
        self.head = head

    def forward(self, x, targets=None):
        """
        Run the model on an input batch.

        Args:
            x: input batch; fed to the backbone and, in training mode, also
                passed to the head alongside the features and targets.
            targets: ground-truth annotations. Required in training mode,
                ignored in eval mode.

        Returns:
            In training mode, a dict with keys ``total_loss``, ``iou_loss``,
            ``l1_loss``, ``conf_loss``, ``cls_loss`` and ``num_fg`` unpacked
            from the head's output tuple. In eval mode, the head's output
            on the backbone features, unchanged.

        Raises:
            ValueError: if called in training mode without ``targets``.
        """
        # Backbone output: FPN features for the [dark3, dark4, dark5] levels.
        fpn_outs = self.backbone(x)

        if not self.training:
            return self.head(fpn_outs)

        # Explicit check instead of ``assert`` so validation survives
        # ``python -O`` rather than failing obscurely inside the head.
        if targets is None:
            raise ValueError(
                "targets must be provided when the model is in training mode"
            )
        loss, iou_loss, conf_loss, cls_loss, l1_loss, num_fg = self.head(
            fpn_outs, targets, x
        )
        return {
            "total_loss": loss,
            "iou_loss": iou_loss,
            "l1_loss": l1_loss,
            "conf_loss": conf_loss,
            "cls_loss": cls_loss,
            "num_fg": num_fg,
        }

    def visualize(self, x, targets, save_prefix="assign_vis_"):
        """
        Run the backbone and delegate assignment visualization to the head.

        Args:
            x: input batch fed to the backbone.
            targets: ground-truth annotations forwarded to the head.
            save_prefix: forwarded to ``head.visualize_assign_result``
                (presumably a filename prefix for saved images -- confirm
                against the head implementation).
        """
        fpn_outs = self.backbone(x)
        self.head.visualize_assign_result(fpn_outs, targets, x, save_prefix)