sshi committed
Commit c37ceb0
1 Parent(s): f9692cc

Add application file

Files changed (1)
  1. app.py +154 -0
app.py ADDED
@@ -0,0 +1,154 @@
+ import gradio as gr
+ import os
+ import torch
+ import pytorch_lightning as pl
+
+ # torch.hub.download_url_to_file('http://images.cocodataset.org/val2017/000000039769.jpg', 'cats.jpg')
+ # torch.hub.download_url_to_file('https://huggingface.co/datasets/nielsr/textcaps-sample/resolve/main/stop_sign.png', 'stop_sign.png')
+ # torch.hub.download_url_to_file('https://cdn.openai.com/dall-e-2/demos/text2im/astronaut/horse/photo/0.jpg', 'astronaut.jpg')
+
+ # os.system("wget https://github.com/hustvl/YOLOP/raw/main/weights/End-to-end.pth")
+
+ from transformers import AutoFeatureExtractor, AutoModelForObjectDetection
+
+ from PIL import Image
+ import matplotlib.pyplot as plt
+
+
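+ # Lightning wrapper used to fine-tune YOLOS; the same class definition is
+ # needed here so the Lightning checkpoint below can be restored.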
+ class Detr(pl.LightningModule):
+
+     def __init__(self, lr, weight_decay):
+         super().__init__()
+         # replace COCO classification head with custom head
+         self.model = AutoModelForObjectDetection.from_pretrained("hustvl/yolos-small",
+                                                                  num_labels=len(id2label),
+                                                                  ignore_mismatched_sizes=True)
+         # see https://github.com/PyTorchLightning/pytorch-lightning/pull/1896
+         self.lr = lr
+         self.weight_decay = weight_decay
+
+     def forward(self, pixel_values):
+         outputs = self.model(pixel_values=pixel_values)
+
+         return outputs
+
+     def common_step(self, batch, batch_idx):
+         pixel_values = batch["pixel_values"]
+         labels = [{k: v.to(self.device) for k, v in t.items()} for t in batch["labels"]]
+
+         outputs = self.model(pixel_values=pixel_values, labels=labels)
+
+         loss = outputs.loss
+         loss_dict = outputs.loss_dict
+
+         return loss, loss_dict
+
+     def training_step(self, batch, batch_idx):
+         loss, loss_dict = self.common_step(batch, batch_idx)
+         # logs metrics for each training_step,
+         # and the average across the epoch
+         self.log("training_loss", loss)
+         for k, v in loss_dict.items():
+             self.log("train_" + k, v.item())
+
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         loss, loss_dict = self.common_step(batch, batch_idx)
+         self.log("validation_loss", loss)
+         for k, v in loss_dict.items():
+             self.log("validation_" + k, v.item())
+
+         return loss
+
+     def configure_optimizers(self):
+         optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr,
+                                       weight_decay=self.weight_decay)
+
+         return optimizer
+
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ feature_extractor = AutoFeatureExtractor.from_pretrained("hustvl/yolos-small", size=512, max_size=864)
+
+ # label map for the traffic-object classes; defined before the checkpoint is
+ # loaded because Detr.__init__ references it via len(id2label)
+ id2label = {1: 'person', 2: 'rider', 3: 'car', 4: 'bus', 5: 'truck', 6: 'bike', 7: 'motor', 8: 'traffic light', 9: 'traffic sign', 10: 'train'}
+
+ # Build model and load checkpoint
+ checkpoint = 'fintune_traffic_object.ckpt'
+ model = Detr.load_from_checkpoint(checkpoint, lr=2.5e-5, weight_decay=1e-4)
+
+ model.to(device)
+ model.eval()
+
+ # colors for visualization
+ COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
+           [0.756, 0.794, 0.100], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933],
+           [0.184, 0.494, 0.741], [0.494, 0.674, 0.556], [0.494, 0.301, 0.933],
+           [0.000, 0.325, 0.850], [0.745, 0.301, 0.188]]
+
+ # for output bounding box post-processing
+ def box_cxcywh_to_xyxy(x):
+     # convert (center_x, center_y, width, height) boxes to (x_min, y_min, x_max, y_max)
+     x_c, y_c, w, h = x.unbind(1)
+     b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
+          (x_c + 0.5 * w), (y_c + 0.5 * h)]
+     return torch.stack(b, dim=1)
+
+ def rescale_bboxes(out_bbox, size):
+     # scale normalized boxes up to the original image size
+     img_w, img_h = size
+     b = box_cxcywh_to_xyxy(out_bbox)
+     b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
+     return b
+
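+ # draw the kept boxes with their class labels via matplotlib and return
+ # the rendered figure as a PIL image for Gradio to display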
+ def plot_results(pil_img, prob, boxes):
+     fig = plt.figure(figsize=(16, 10))
+     plt.imshow(pil_img)
+     ax = plt.gca()
+     colors = COLORS * 100
+     for p, (xmin, ymin, xmax, ymax) in zip(prob, boxes.tolist()):
+         cl = p.argmax()
+         c = colors[cl]
+         ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
+                                    fill=False, color=c, linewidth=2))
+         text = f'{id2label[cl.item()]}: {p[cl]:0.2f}'
+         ax.text(xmin, ymin, text, fontsize=10,
+                 bbox=dict(facecolor=c, alpha=0.5))
+     plt.axis('off')
+     # render the canvas before reading back its RGB buffer
+     fig.canvas.draw()
+     result = Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
+     plt.close(fig)
+     return result
+
+
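+ # preprocess a single PIL image and run one forward pass of the model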
+ def generate_preds(processor, model, image):
+     # the feature extractor already returns a batched (1, C, H, W) tensor,
+     # so no extra unsqueeze is needed
+     inputs = processor(images=image, return_tensors="pt").to(device)
+     with torch.no_grad():
+         preds = model(pixel_values=inputs.pixel_values)
+     return preds
+
+
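+ # drop low-confidence predictions and map the surviving boxes from
+ # normalized coordinates back onto the input image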
+ def visualize_preds(image, preds, threshold=0.9):
+     # keep only predictions with confidence above the threshold
+     probas = preds.logits.softmax(-1)[0, :, :-1]
+     keep = probas.max(-1).values > threshold
+
+     # convert predicted boxes from [0, 1] to image scales
+     bboxes_scaled = rescale_bboxes(preds.pred_boxes[0, keep].cpu(), image.size)
+
+     return plot_results(image, probas[keep], bboxes_scaled)
+
+
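+ # Gradio callback: PIL image in, annotated PIL image out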
+ def detect(img):
+
+     # Run inference with the globally loaded model
+     preds = generate_preds(feature_extractor, model, img)
+
+     return visualize_preds(img, preds)
+
+
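+ # assemble the Gradio demo around the detect() callback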
+ interface = gr.Interface(
+     fn=detect,
+     inputs=[gr.Image(type="pil")],
+     outputs=gr.Image(type="pil"),
+     # examples=[["example1.jpeg"], ["example2.jpeg"], ["example3.jpeg"]],
+     title="YOLOS for traffic object detection",
+     description="A downstream application of <a href='https://huggingface.co/docs/transformers/model_doc/yolos' style='text-decoration: underline' target='_blank'>YOLOS</a> for traffic object detection.")
+
+ interface.launch()