# nathbotbol's picture
# Upload folder using huggingface_hub
# d7aea57 verified
from typing import List, Tuple
import pytorch_lightning as pl
import torch
from sacred import Ingredient
from torch import nn
from transformers import AdamW, DetrForObjectDetection, DetrImageProcessor
# Sacred ingredient named "detr". The config function, the DeTrLightning
# methods and the factory functions in this module are registered on it via
# ``@detr_ingredient.config`` / ``@detr_ingredient.capture``.
detr_ingredient = Ingredient("detr", save_git_info=False)
# pylint: disable=unused-variable
@detr_ingredient.config
def config():
    """Default values for the ``detr`` ingredient.

    Each local variable below is a configuration entry; captured functions in
    this module receive the entries whose names match their parameters.
    """
    # Checkpoint identifier handed to ``from_pretrained`` for both the DETR
    # model and the image processor.
    hf_ckpt = "facebook/detr-resnet-50"
    # Lightning checkpoint that ``get_detr`` loads when this value is truthy.
    model_path = "./models/detr.ckpt"
    # NOTE(review): not referenced in the visible part of this module; the
    # model derives its class count from ``len(labels)`` instead.
    num_label = 3
    # Class names; the position in the list is used as the class id.
    labels = ["Tableau Electrique", "Disjoncteur", "Bouton de test"]
    # Learning rate passed to the optimizer in ``configure_optimizers``.
    learning_rate = 1e-5
class DeTrLightning(pl.LightningModule):
    """Lightning module wrapping a Hugging Face ``DetrForObjectDetection``.

    ``__init__`` and ``configure_optimizers`` are captured by the ``detr``
    sacred ingredient, so their arguments are filled in from its
    configuration when they are not passed explicitly.
    """

    @detr_ingredient.capture
    def __init__(self, hf_ckpt: str, labels: List[str]):
        """Load the DETR checkpoint with a classification head for ``labels``.

        Args:
            hf_ckpt: Checkpoint identifier passed to ``from_pretrained``.
            labels: Class names; each name's index in the list is its class id.
        """
        super().__init__()
        self.model = DetrForObjectDetection.from_pretrained(
            hf_ckpt,
            num_labels=len(labels),
            id2label=dict(enumerate(labels)),
            label2id={label: i for i, label in enumerate(labels)},
            # The head is sized from ``labels`` and may therefore differ from
            # the one stored in the checkpoint.
            ignore_mismatched_sizes=True,
        )

    def forward(self, *args, **kwargs):
        """Forward every argument to the wrapped DETR model unchanged."""
        return self.model(*args, **kwargs)

    # pylint: disable=arguments-differ
    def training_step(self, batch):
        """Compute and log the training loss for one batch.

        Args:
            batch: Mapping whose entries are passed to the model as keyword
                arguments (it must include ``labels`` for a loss to be
                produced).

        Returns:
            The ``loss`` entry of the model output.
        """
        output = self.model(**batch)
        loss = output["loss"]
        self.log("train_loss", loss)
        return loss

    def _evaluation_loss(self, batch):
        """Run the model on an evaluation batch and return its loss."""
        output = self.model(
            batch["pixel_values"],
            pixel_mask=batch["pixel_mask"],
            labels=batch["labels"],
        )
        return output["loss"]

    # pylint: disable=unused-argument
    def validation_step(self, batch, batch_id):
        """Compute and log the validation loss for one batch.

        Args:
            batch: Mapping with ``pixel_values``, ``pixel_mask`` and ``labels``.
            batch_id: Index of the batch (unused).

        Returns:
            The ``loss`` entry of the model output.
        """
        loss = self._evaluation_loss(batch)
        self.log("val_loss", loss)
        # The validation loss used to be logged only under "test_loss"; that
        # key is kept so existing callbacks monitoring it still find it.
        self.log("test_loss", loss)
        return loss

    # pylint: disable=unused-argument
    def test_step(self, batch, batch_id):
        """Compute and log the test loss for one batch.

        Args:
            batch: Mapping with ``pixel_values``, ``pixel_mask`` and ``labels``.
            batch_id: Index of the batch (unused).

        Returns:
            The ``loss`` entry of the model output.
        """
        loss = self._evaluation_loss(batch)
        self.log("test_loss", loss)
        return loss

    # pylint: disable=arguments-differ
    @detr_ingredient.capture
    def configure_optimizers(self, learning_rate):
        """Build the AdamW optimizer over all parameters of this module.

        Args:
            learning_rate: Learning rate handed to the optimizer.

        Returns:
            The optimizer instance.
        """
        # NOTE(review): ``transformers.AdamW`` is deprecated upstream in favour
        # of ``torch.optim.AdamW``; the two have different defaults, so confirm
        # the intended hyper-parameters before migrating.
        return AdamW(self.parameters(), lr=learning_rate)

    @staticmethod
    def box_cxcywh_to_xyxy(boxes):
        """Convert boxes from (center x, center y, width, height) to corners.

        Args:
            boxes: Tensor whose last dimension holds the four box values.

        Returns:
            Tensor of the same shape holding (top-left x, top-left y,
            bottom-right x, bottom-right y) in the last dimension.
        """
        center_x, center_y, width, height = boxes.unbind(-1)
        boxes = torch.stack(
            # top left x, top left y, bottom right x, bottom right y
            [
                (center_x - 0.5 * width),
                (center_y - 0.5 * height),
                (center_x + 0.5 * width),
                (center_y + 0.5 * height),
            ],
            dim=-1,
        )
        return boxes

    def process_output(self, model_output, image_size: Tuple[int, int]):
        """Turn raw model output into scaled corner boxes, labels and scores.

        Args:
            model_output: Object exposing ``logits`` and ``pred_boxes``.
            image_size: Pair of scale factors; element 0 multiplies the x
                coordinates and element 1 the y coordinates.

        Returns:
            Tuple ``(boxes, labels, scores)`` where ``boxes`` are corner-format
            boxes multiplied by ``image_size`` and ``labels``/``scores`` are
            the arg-max index and value of the softmax over the last logit
            dimension.
        """
        # NOTE(review): the max runs over every logit column, none is
        # excluded. If the head appends a "no object" class, callers have to
        # filter that label themselves -- confirm against the call sites.
        prob = model_output.logits.softmax(-1)
        scores, labels = prob.max(-1)
        # Assumes pred_boxes are normalized (cx, cy, w, h) -- TODO confirm.
        boxes = self.box_cxcywh_to_xyxy(model_output.pred_boxes)
        # Build the scale on the same device as the boxes; a CPU-only tensor
        # breaks the multiplication when the model output lives on a GPU.
        scale_fct = torch.tensor(
            [image_size[0], image_size[1], image_size[0], image_size[1]],
            dtype=torch.float32,
            device=boxes.device,
        ).view(1, 1, 4)
        boxes = boxes * scale_fct
        return boxes, labels, scores
@detr_ingredient.capture
def get_detr(model_path: str):
    """Return a ``DeTrLightning`` module.

    When ``model_path`` is truthy the module is restored from that Lightning
    checkpoint; otherwise a new instance is constructed with no explicit
    arguments.
    """
    if not model_path:
        return DeTrLightning()
    return DeTrLightning.load_from_checkpoint(model_path)
@detr_ingredient.capture
def get_detr_feature_extractor(hf_ckpt):
    """Return the ``DetrImageProcessor`` loaded from ``hf_ckpt``."""
    processor = DetrImageProcessor.from_pretrained(hf_ckpt)
    return processor