from typing import List, Tuple

import pytorch_lightning as pl
import torch
from sacred import Ingredient
from torch.optim import AdamW
from transformers import DetrForObjectDetection, DetrImageProcessor

detr_ingredient = Ingredient("detr", save_git_info=False)


# pylint: disable=unused-variable
@detr_ingredient.config
def config():
    hf_ckpt = "facebook/detr-resnet-50"
    model_path = "./models/detr.ckpt"
    num_label = 3
    labels = ["Tableau Electrique", "Disjoncteur", "Bouton de test"]
    learning_rate = 1e-5
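

# Any of these values can be overridden per run from a host Experiment that
# includes this ingredient, e.g.:
#   ex.run(config_updates={"detr": {"learning_rate": 5e-6}})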


class DeTrLightning(pl.LightningModule):
    def __init__(self, hf_ckpt: str, labels: List[str], learning_rate: float = 1e-5):
        super().__init__()
        # Persist the init arguments so load_from_checkpoint can restore them.
        self.save_hyperparameters()
        self.model = DetrForObjectDetection.from_pretrained(
            hf_ckpt,
            num_labels=len(labels),
            id2label=dict(enumerate(labels)),
            label2id={label: i for i, label in enumerate(labels)},
            ignore_mismatched_sizes=True,
        )

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    # pylint: disable=arguments-differ
    def training_step(self, batch):
        # The collated batch already holds pixel_values, pixel_mask and
        # labels, so it can be unpacked straight into the model.
        output = self.model(**batch)
        loss = output["loss"]
        self.log("train_loss", loss)
        return loss

    # pylint: disable=unused-argument
    def validation_step(self, batch, batch_id):
        inputs, mask, targets = (
            batch["pixel_values"],
            batch["pixel_mask"],
            batch["labels"],
        )
        output = self.model(inputs, pixel_mask=mask, labels=targets)
        loss = output["loss"]
        self.log("val_loss", loss)
        return loss

    # pylint: disable=unused-argument
    def test_step(self, batch, batch_id):
        inputs, mask, targets = (
            batch["pixel_values"],
            batch["pixel_mask"],
            batch["labels"],
        )
        output = self.model(inputs, pixel_mask=mask, labels=targets)
        loss = output["loss"]
        self.log("test_loss", loss)
        return loss

    def configure_optimizers(self):
        # Lightning calls this hook without arguments; the learning rate is
        # read from the hyperparameters saved in __init__.
        return AdamW(self.parameters(), lr=self.hparams.learning_rate)

    @staticmethod
    def box_cxcywh_to_xyxy(boxes):
        # Convert DETR's (center_x, center_y, width, height) boxes to
        # (top left x, top left y, bottom right x, bottom right y),
        # e.g. (0.5, 0.5, 0.2, 0.4) -> (0.4, 0.3, 0.6, 0.7).
        center_x, center_y, width, height = boxes.unbind(-1)
        boxes = torch.stack(
            [
                (center_x - 0.5 * width),
                (center_y - 0.5 * height),
                (center_x + 0.5 * width),
                (center_y + 0.5 * height),
            ],
            dim=-1,
        )
        return boxes

    def process_output(self, model_output, image_size: Tuple[int, int]):
        # Drop the last class: DETR reserves it for "no object" predictions.
        prob = model_output.logits.softmax(-1)
        scores, labels = prob[..., :-1].max(-1)
        # Rescale the normalized boxes to pixel coordinates;
        # image_size is (width, height).
        boxes = self.box_cxcywh_to_xyxy(model_output.pred_boxes)
        scale_fct = torch.Tensor(
            [image_size[0], image_size[1], image_size[0], image_size[1]]
        ).unsqueeze(0)
        boxes = boxes * scale_fct[:, None, :]
        return boxes, labels, scores
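

# A minimal training sketch under assumed names: `train_loader` and
# `val_loader` are hypothetical DataLoaders whose collate_fn batches images
# with DetrImageProcessor and yields dicts holding the pixel_values,
# pixel_mask and labels keys the step methods above expect.
#
#     model = get_detr()  # inside a sacred run, or DeTrLightning(hf_ckpt, labels)
#     trainer = pl.Trainer(max_epochs=10)
#     trainer.fit(model, train_loader, val_loader)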


@detr_ingredient.capture
def get_detr(model_path: str, hf_ckpt: str, labels: List[str]):
    # With @capture, sacred fills any argument not passed explicitly from
    # the ingredient config above.
    if model_path:
        return DeTrLightning.load_from_checkpoint(model_path)
    return DeTrLightning(hf_ckpt, labels)


@detr_ingredient.capture
def get_detr_feature_extractor(hf_ckpt):
    return DetrImageProcessor.from_pretrained(hf_ckpt)
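

# A minimal end-to-end sketch, assuming a local test image at a hypothetical
# path: it wires the ingredient into a throwaway sacred Experiment and runs
# one image through the model. model_path is blanked so the model is built
# from the Hugging Face weights rather than a local checkpoint.
if __name__ == "__main__":
    from PIL import Image
    from sacred import Experiment

    ex = Experiment("detr_demo", ingredients=[detr_ingredient], save_git_info=False)

    @ex.main
    def demo():
        image = Image.open("./data/sample.jpg")  # hypothetical test image
        processor = get_detr_feature_extractor()
        model = get_detr()
        model.eval()
        encoding = processor(images=image, return_tensors="pt")
        with torch.no_grad():
            output = model(**encoding)
        # PIL's image.size is (width, height), the order process_output
        # expects for rescaling.
        boxes, labels, scores = model.process_output(output, image.size)
        keep = scores > 0.9  # confidence threshold, tune as needed
        print(boxes[keep], labels[keep], scores[keep])

    ex.run(config_updates={"detr": {"model_path": ""}})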