Spaces:
Running
Running
File size: 3,467 Bytes
d7aea57 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
from typing import List, Tuple
import pytorch_lightning as pl
import torch
from sacred import Ingredient
from torch import nn
from transformers import AdamW, DetrForObjectDetection, DetrImageProcessor
# Sacred ingredient named "detr". The config scope and the captured functions
# below are registered on it, so their arguments are filled from its config.
# save_git_info=False is passed so Sacred does not collect git information
# for this ingredient's sources (see the Sacred Ingredient documentation).
detr_ingredient = Ingredient("detr", save_git_info=False)
# pylint: disable=unused-variable
@detr_ingredient.config
def config():
    """Default configuration entries of the ``detr`` ingredient.

    Every local variable assigned here becomes a configuration entry, so the
    names below must not be renamed: captured functions receive them by name.
    """
    # Checkpoint identifier handed to ``from_pretrained`` for both the model
    # and the image processor.
    hf_ckpt = "facebook/detr-resnet-50"
    # Lightning checkpoint loaded by ``get_detr``; a falsy value makes
    # ``get_detr`` build a fresh model from ``hf_ckpt`` instead.
    model_path = "./models/detr.ckpt"
    # Number of object classes.
    # NOTE(review): no function visible in this file reads this entry (the
    # model derives its class count from ``len(labels)``); keep both in sync.
    num_label = 3
    # Class names; the position in the list is the integer id of the label.
    labels = ["Tableau Electrique", "Disjoncteur", "Bouton de test"]
    # Learning rate given to the optimizer in ``configure_optimizers``.
    learning_rate = 1e-5
class DeTrLightning(pl.LightningModule):
    """Lightning module wrapping a Hugging Face ``DetrForObjectDetection``.

    The wrapped model computes its own loss when it is given ``labels``, so
    the train/validation/test steps only forward the batch, log the loss and
    return it. ``process_output`` turns raw model outputs into boxes in
    absolute corner coordinates together with a label and score per query.
    """

    @detr_ingredient.capture
    def __init__(self, hf_ckpt: str, labels: List[str]):
        """Load the pretrained checkpoint with a head sized for ``labels``.

        Both arguments are injected from the ``detr`` ingredient configuration
        when they are not passed explicitly.

        Args:
            hf_ckpt: Checkpoint identifier handed to ``from_pretrained``.
            labels: Class names; the index of a name is its integer label id.
        """
        super().__init__()
        self.model = DetrForObjectDetection.from_pretrained(
            hf_ckpt,
            num_labels=len(labels),
            id2label=dict(enumerate(labels)),
            label2id={label: i for i, label in enumerate(labels)},
            # Allows loading even though the checkpoint's classification head
            # was trained for a different number of labels.
            ignore_mismatched_sizes=True,
        )

    def forward(self, *args, **kwargs):
        """Delegate the call unchanged to the wrapped DETR model."""
        return self.model(*args, **kwargs)

    # pylint: disable=arguments-differ
    def training_step(self, batch):
        """Run one training step and return the loss computed by the model.

        Args:
            batch: Mapping whose entries are all passed to the model as
                keyword arguments; it must contain what the model needs to
                compute a loss (the other steps read ``pixel_values``,
                ``pixel_mask`` and ``labels`` from the same kind of batch).

        Returns:
            The ``"loss"`` entry of the model output.
        """
        loss = self.model(**batch)["loss"]
        self.log("train_loss", loss)
        return loss

    def _evaluation_loss(self, batch):
        """Forward an evaluation batch and return the model's loss."""
        output = self.model(
            batch["pixel_values"],
            pixel_mask=batch["pixel_mask"],
            labels=batch["labels"],
        )
        return output["loss"]

    # pylint: disable=unused-argument
    def validation_step(self, batch, batch_id):
        """Compute and log the loss of one validation batch.

        Args:
            batch: Mapping with ``pixel_values``, ``pixel_mask`` and ``labels``.
            batch_id: Index of the batch (unused).

        Returns:
            The loss computed by the model for this batch.
        """
        loss = self._evaluation_loss(batch)
        self.log("val_loss", loss)
        # Earlier versions logged the validation loss under "test_loss" only.
        # The key is kept so that callbacks or dashboards configured outside
        # this file that monitor it keep working.
        self.log("test_loss", loss)
        return loss

    # pylint: disable=unused-argument
    def test_step(self, batch, batch_id):
        """Compute and log the loss of one test batch.

        Args:
            batch: Mapping with ``pixel_values``, ``pixel_mask`` and ``labels``.
            batch_id: Index of the batch (unused).

        Returns:
            The loss computed by the model for this batch.
        """
        loss = self._evaluation_loss(batch)
        self.log("test_loss", loss)
        return loss

    # pylint: disable=arguments-differ
    @detr_ingredient.capture
    def configure_optimizers(self, learning_rate):
        """Return an AdamW optimizer over all parameters of the module.

        Args:
            learning_rate: Injected from the ``detr`` ingredient configuration
                when not passed explicitly.
        """
        optimizer = AdamW(self.parameters(), lr=learning_rate)
        return optimizer

    @staticmethod
    def box_cxcywh_to_xyxy(boxes):
        """Convert boxes from (center x, center y, width, height) to corners.

        Args:
            boxes: Tensor whose last dimension holds the four values
                ``(center_x, center_y, width, height)``.

        Returns:
            Tensor of the same shape whose last dimension holds
            ``(top left x, top left y, bottom right x, bottom right y)``.
        """
        center_x, center_y, width, height = boxes.unbind(-1)
        half_width = 0.5 * width
        half_height = 0.5 * height
        return torch.stack(
            [
                center_x - half_width,
                center_y - half_height,
                center_x + half_width,
                center_y + half_height,
            ],
            dim=-1,
        )

    def process_output(self, model_output, image_size: Tuple[int, int]):
        """Convert raw model outputs into scaled boxes, labels and scores.

        Args:
            model_output: Object exposing ``logits`` and ``pred_boxes``.
            image_size: Two values; the first scales the x coordinates and the
                second scales the y coordinates, i.e. ``(width, height)``.

        Returns:
            A ``(boxes, labels, scores)`` tuple: corner-format boxes multiplied
            by the image size, and per query the index and probability of the
            most likely class.
        """
        prob = model_output.logits.softmax(-1)
        # NOTE(review): the maximum is taken over every logit column. If the
        # model emits a trailing "no object" class (as documented for DETR),
        # it is not excluded here -- confirm that callers filter it out.
        scores, labels = prob.max(-1)
        boxes = self.box_cxcywh_to_xyxy(model_output.pred_boxes)
        # Build the scale factors on the same device as the boxes; a tensor
        # created on the CPU cannot be multiplied with GPU predictions.
        scale_fct = torch.tensor(
            [image_size[0], image_size[1], image_size[0], image_size[1]],
            dtype=torch.get_default_dtype(),
            device=boxes.device,
        ).reshape(1, 1, 4)
        boxes = boxes * scale_fct
        return boxes, labels, scores
@detr_ingredient.capture
def get_detr(model_path: str):
    """Return a ``DeTrLightning`` module, restored from a checkpoint if given.

    Args:
        model_path: Path of a Lightning checkpoint, injected from the ``detr``
            ingredient configuration when not passed explicitly. A falsy value
            (for example an empty string) requests a freshly built module.
    """
    if not model_path:
        return DeTrLightning()
    return DeTrLightning.load_from_checkpoint(model_path)
@detr_ingredient.capture
def get_detr_feature_extractor(hf_ckpt):
    """Return the ``DetrImageProcessor`` loaded from ``hf_ckpt``.

    Args:
        hf_ckpt: Checkpoint identifier handed to ``from_pretrained``, injected
            from the ``detr`` ingredient configuration when not passed
            explicitly.
    """
    processor = DetrImageProcessor.from_pretrained(hf_ckpt)
    return processor
|