from typing import List, Tuple

import pytorch_lightning as pl
import torch
from sacred import Ingredient
from torch import nn
from transformers import DetrForObjectDetection, DetrImageProcessor

detr_ingredient = Ingredient("detr", save_git_info=False)


# pylint: disable=unused-variable
@detr_ingredient.config
def config():
    """Sacred configuration for the ``detr`` ingredient.

    Each local variable below becomes a config entry; functions decorated with
    ``@detr_ingredient.capture`` receive them by parameter name.
    """
    hf_ckpt = "facebook/detr-resnet-50"
    model_path = "./models/detr.ckpt"
    # NOTE(review): not read anywhere in this module -- the class count is
    # derived from ``len(labels)``. Kept so existing config overrides still work.
    num_label = 3
    labels = ["Tableau Electrique", "Disjoncteur", "Bouton de test"]
    learning_rate = 1e-5


class DeTrLightning(pl.LightningModule):
    """Lightning wrapper around a Hugging Face ``DetrForObjectDetection`` model."""

    @detr_ingredient.capture
    def __init__(self, hf_ckpt: str, labels: List[str]):
        """Load pretrained DETR weights with a classification head sized for ``labels``.

        Args:
            hf_ckpt: Hugging Face checkpoint name or path.
            labels: Class names; a name's index in this list is its class id.
        """
        super().__init__()
        self.model = DetrForObjectDetection.from_pretrained(
            hf_ckpt,
            num_labels=len(labels),
            id2label=dict(enumerate(labels)),
            label2id={label: i for i, label in enumerate(labels)},
            # The checkpoint's head has a different class count; let HF
            # re-initialise the mismatched layers instead of raising.
            ignore_mismatched_sizes=True,
        )

    def forward(self, *args, **kwargs):
        """Delegate directly to the wrapped Hugging Face model."""
        return self.model(*args, **kwargs)

    @staticmethod
    def _batch_size(batch) -> int:
        """Return the number of images in ``batch`` for ``self.log``."""
        # Passed explicitly because Lightning's automatic inference is ambiguous
        # for a dict batch that also carries a ``labels`` collection
        # (presumably one target dict per image -- confirm with the collate fn).
        return len(batch["pixel_values"])

    def _eval_loss(self, batch):
        """Run the model on a validation/test batch and return its loss."""
        output = self.model(
            batch["pixel_values"],
            pixel_mask=batch["pixel_mask"],
            labels=batch["labels"],
        )
        return output["loss"]

    # pylint: disable=arguments-differ
    def training_step(self, batch):
        """Compute, log (as ``train_loss``) and return the training loss.

        Every key of ``batch`` is forwarded to the model as a keyword argument.
        """
        output = self.model(**batch)
        loss = output["loss"]
        self.log("train_loss", loss, batch_size=self._batch_size(batch))
        return loss

    # pylint: disable=unused-argument
    def validation_step(self, batch, batch_id):
        """Compute, log (as ``val_loss``) and return the validation loss."""
        loss = self._eval_loss(batch)
        # Previously logged as "test_loss", which collided with test_step.
        self.log("val_loss", loss, batch_size=self._batch_size(batch))
        return loss

    # pylint: disable=unused-argument
    def test_step(self, batch, batch_id):
        """Compute, log (as ``test_loss``) and return the test loss."""
        loss = self._eval_loss(batch)
        self.log("test_loss", loss, batch_size=self._batch_size(batch))
        return loss

    # pylint: disable=arguments-differ
    @detr_ingredient.capture
    def configure_optimizers(self, learning_rate):
        """Create the optimizer; ``learning_rate`` comes from the ``detr`` config."""
        # ``transformers.AdamW`` is deprecated and removed in recent releases.
        # torch's AdamW with eps=1e-6 and no weight decay mirrors its defaults.
        return torch.optim.AdamW(
            self.parameters(), lr=learning_rate, eps=1e-6, weight_decay=0.0
        )

    @staticmethod
    def box_cxcywh_to_xyxy(boxes):
        """Convert boxes from ``(cx, cy, w, h)`` to ``(x0, y0, x1, y1)`` on the last axis."""
        center_x, center_y, width, height = boxes.unbind(-1)
        half_w = 0.5 * width
        half_h = 0.5 * height
        # top-left x, top-left y, bottom-right x, bottom-right y
        return torch.stack(
            [
                center_x - half_w,
                center_y - half_h,
                center_x + half_w,
                center_y + half_h,
            ],
            dim=-1,
        )

    def process_output(self, model_output, image_size: Tuple[int, int]):
        """Turn raw DETR predictions into pixel-space boxes, labels and scores.

        Args:
            model_output: Model output exposing ``logits`` and ``pred_boxes``.
            image_size: Two scale factors for the normalized boxes. The first
                scales the x coordinates and the second the y coordinates, so
                this is presumably ``(width, height)`` (the order of
                ``PIL.Image.size``) -- confirm against callers.

        Returns:
            ``(boxes, labels, scores)``: boxes as ``(x0, y0, x1, y1)`` scaled by
            ``image_size``, the arg-max class id per query, and its softmax
            probability.
        """
        prob = model_output.logits.softmax(-1)
        # NOTE(review): per the HF docs the last logit column is the "no object"
        # class. It is included in this max, so a query can be labelled with
        # that extra index; consider ``prob[..., :-1]`` if callers don't filter it.
        scores, labels = prob.max(-1)
        boxes = self.box_cxcywh_to_xyxy(model_output.pred_boxes)
        size_x, size_y = image_size
        # Create the scale on the boxes' device so GPU outputs don't hit a
        # device mismatch; broadcasts over all leading dimensions.
        scale_fct = torch.tensor(
            [size_x, size_y, size_x, size_y],
            dtype=torch.get_default_dtype(),
            device=boxes.device,
        )
        return boxes * scale_fct, labels, scores


@detr_ingredient.capture
def get_detr(model_path: str):
    """Return the DETR module, loaded from ``model_path`` when it is set.

    A falsy ``model_path`` (e.g. empty string or None) builds a fresh module
    from the configured ``hf_ckpt`` instead.
    """
    if model_path:
        return DeTrLightning.load_from_checkpoint(model_path)
    return DeTrLightning()


@detr_ingredient.capture
def get_detr_feature_extractor(hf_ckpt):
    """Return the ``DetrImageProcessor`` matching the configured checkpoint."""
    return DetrImageProcessor.from_pretrained(
        hf_ckpt,
    )