File size: 794 Bytes
1865436
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import pytorch_lightning as pl
from src.ss.det_models.backbone import initialize_model

class POIDetection(pl.LightningModule):
    """LightningModule wrapping a model built by ``initialize_model``.

    Args:
        n_classes: Number of classes, forwarded to ``initialize_model``.
        **kwargs: Extra hyperparameters. Two keys are required and read in
            ``__init__``:

            * ``backbone`` -- passed as the first argument to
              ``initialize_model``.
            * ``tune_fc_only`` -- passed as ``tune_only`` to
              ``initialize_model``.

            Any other keys are accepted and only recorded via
            ``save_hyperparameters()``.

    Raises:
        KeyError: If ``backbone`` or ``tune_fc_only`` is missing from
            ``kwargs``.
    """

    def __init__(self,
                 n_classes,
                 **kwargs):
        super().__init__()
        # Lightning collects this __init__'s arguments (n_classes and the
        # kwargs) into self.hparams; it must be called from __init__ itself.
        self.save_hyperparameters()
        # initialize_model returns a pair; only the model is kept.
        # NOTE(review): the second element is discarded -- confirm nothing
        # downstream needs it.
        self.model, _ = initialize_model(kwargs["backbone"],
                                         n_classes,
                                         tune_only=kwargs["tune_fc_only"])

    def forward(self, images, targets=None):
        """Run the wrapped model on a batch of images.

        Args:
            images: Iterable of images, materialized into a ``list`` before
                being passed to the model (presumably image tensors --
                confirm against the data pipeline).
            targets: Optional iterable of per-image mappings. When given,
                each mapping is shallow-copied into a new ``dict`` so the
                caller's own objects are not the ones handed to the model.

        Returns:
            ``self.model(images, targets)`` when ``targets`` is given,
            otherwise ``self.model(images)``.
            NOTE(review): for a torchvision-style detector this would be a
            loss dict in training mode and per-image detections in eval
            mode -- confirm against ``initialize_model``.
        """
        images = list(images)
        if targets is None:
            return self.model(images)
        # Shallow copy of each target mapping (values are not copied).
        targets = [dict(t.items()) for t in targets]
        return self.model(images, targets)