import gradio as gr
import os
import torch
import pytorch_lightning as pl

# torch.hub.download_url_to_file('http://images.cocodataset.org/val2017/000000039769.jpg', 'cats.jpg')
# torch.hub.download_url_to_file('https://huggingface.co/datasets/nielsr/textcaps-sample/resolve/main/stop_sign.png', 'stop_sign.png')
# torch.hub.download_url_to_file('https://cdn.openai.com/dall-e-2/demos/text2im/astronaut/horse/photo/0.jpg', 'astronaut.jpg')

# os.system("wget https://github.com/hustvl/YOLOP/raw/main/weights/End-to-end.pth")

from transformers import AutoFeatureExtractor, AutoModelForObjectDetection

from PIL import Image, ImageDraw
import cv2
import matplotlib.pyplot as plt

# Label names for the traffic object categories predicted by the fine-tuned checkpoint
id2label = {1: 'person', 2: 'rider', 3: 'car', 4: 'bus', 5: 'truck', 6: 'bike', 7: 'motor', 8: 'traffic light', 9: 'traffic sign', 10: 'train'}

# PyTorch Lightning wrapper around YOLOS, mirroring the training-time setup;
# at inference time it is only needed so that load_from_checkpoint can restore
# the fine-tuned weights.
class Detr(pl.LightningModule):

    def __init__(self, lr, weight_decay):
        super().__init__()
        # replace COCO classification head with custom head
        self.model = AutoModelForObjectDetection.from_pretrained(
            "hustvl/yolos-small",
            num_labels=len(id2label),
            ignore_mismatched_sizes=True)
        # see https://github.com/PyTorchLightning/pytorch-lightning/pull/1896
        self.lr = lr
        self.weight_decay = weight_decay

    def forward(self, pixel_values):
        outputs = self.model(pixel_values=pixel_values)

        return outputs

    def common_step(self, batch, batch_idx):
        pixel_values = batch["pixel_values"]
        labels = [{k: v.to(self.device) for k, v in t.items()} for t in batch["labels"]]

        outputs = self.model(pixel_values=pixel_values, labels=labels)

        loss = outputs.loss
        loss_dict = outputs.loss_dict

        return loss, loss_dict

    def training_step(self, batch, batch_idx):
        loss, loss_dict = self.common_step(batch, batch_idx)
        # logs metrics for each training_step,
        # and the average across the epoch
        self.log("training_loss", loss)
        for k, v in loss_dict.items():
            self.log("train_" + k, v.item())

        return loss

    def validation_step(self, batch, batch_idx):
        loss, loss_dict = self.common_step(batch, batch_idx)
        self.log("validation_loss", loss)
        for k, v in loss_dict.items():
            self.log("validation_" + k, v.item())

        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr,
                                      weight_decay=self.weight_decay)

        return optimizer


device = "cuda" if torch.cuda.is_available() else "cpu"

# resize inputs so the shorter side is 512 px, with the longer side capped at 864 px
feature_extractor = AutoFeatureExtractor.from_pretrained("hustvl/yolos-small", size=512, max_size=864)

# Build model and load the fine-tuned checkpoint
checkpoint = './checkpoints/epoch=1-step=2184.ckpt'
# map_location keeps the load working on CPU-only hosts as well
model_yolos = Detr.load_from_checkpoint(checkpoint, lr=2.5e-5, weight_decay=1e-4, map_location=device)

model_yolos.to(device)
model_yolos.eval()

# colors for visualization
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
          [0.756, 0.794, 0.100], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933],
          [0.184, 0.494, 0.741], [0.494, 0.674, 0.556], [0.494, 0.301, 0.933],
          [0.000, 0.325, 0.850], [0.745, 0.301, 0.188]]

# for output bounding box post-processing
def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x.unbind(1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)

def rescale_bboxes(out_bbox, size):
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
    return b
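
# YOLOS predicts boxes as normalized (cx, cy, w, h); the helpers above convert them
# to absolute (xmin, ymin, xmax, ymax) pixel coordinates. For example, a normalized
# prediction (0.5, 0.5, 0.2, 0.4) on a 640x480 image becomes (256.0, 144.0, 384.0, 336.0):
#   rescale_bboxes(torch.tensor([[0.5, 0.5, 0.2, 0.4]]), (640, 480))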

def plot_results(pil_img, prob, boxes):
    draw = ImageDraw.Draw(pil_img)
    colors = COLORS * 100

    for p, (xmin, ymin, xmax, ymax) in zip(prob, boxes.tolist()):
        cl = p.argmax().item()
        # PIL expects integer RGB tuples in the 0-255 range
        c = tuple(int(255 * v) for v in colors[cl])

        draw.rectangle([xmin, ymin, xmax, ymax], outline=c, width=2)
        draw.text(
            (xmin + 5, ymin + 5),
            # fall back to the numeric class id if it has no entry in id2label
            f'{id2label.get(cl, str(cl))}: {p[cl]:0.2f}',
            fill=c)
    # the image was drawn on in place, so return it directly
    return pil_img


def generate_preds(processor, model, image):
    inputs = processor(images=image, return_tensors="pt").to(device)
    # inference only, so no gradients are needed
    with torch.no_grad():
        preds = model(pixel_values=inputs.pixel_values)
    return preds
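
# preds.logits holds one row of num_labels + 1 scores per detection token (the extra
# score is the "no-object" class, which visualize_preds drops via [:, :-1]), and
# preds.pred_boxes holds the matching normalized (cx, cy, w, h) boxes.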


def visualize_preds(image, preds, threshold=0.9):
    # keep only predictions with confidence >= threshold
    probas = preds.logits.softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > threshold
    
    # convert predicted boxes from [0; 1] to image scales
    bboxes_scaled = rescale_bboxes(preds.pred_boxes[0, keep].cpu(), image.size)

    return plot_results(image, probas[keep], bboxes_scaled)


def detect(img):

    # Run inference
    preds = generate_preds(feature_extractor, model_yolos, img)

    return visualize_preds(img, preds)
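
# The pipeline can also be called directly, outside of Gradio; the path below is
# one of the bundled example images:
#   detect(Image.open("./imgs/example1.jpg")).save("example1_pred.jpg")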

   
interface = gr.Interface(
    fn=detect,
    inputs=[gr.Image(type="pil")], 
    outputs=gr.Image(type="pil"),
    # outputs = ['plot'],
    examples=[["./imgs/example1.jpg"], ["./imgs/example2.jpg"]],
    title="YOLOS for traffic object detection",
    description="A downstream application for <a href='https://huggingface.co/docs/transformers/model_doc/yolos' style='text-decoration: underline' target='_blank'>YOLOS</a> which can perform traffic object detection.")

interface.launch()
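
# When running this script locally rather than on a hosted Space, a temporary
# public URL can be obtained with interface.launch(share=True).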