Ahsen Khaliq committed
Commit fa33312 · 1 Parent(s): 4f37cba

Create app.py

Files changed (1)
  1. app.py +148 -0
app.py ADDED
import torch
from PIL import Image
import requests
import torchvision.transforms as T
import matplotlib.pyplot as plt
from collections import defaultdict
import torch.nn.functional as F
import numpy as np
from skimage.measure import find_contours

from matplotlib import patches, lines
from matplotlib.patches import Polygon
import gradio as gr

torch.set_grad_enabled(False)
# standard PyTorch mean-std input image normalization (ImageNet statistics)
transform = T.Compose([
    T.Resize(800),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
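# Note: T.Resize with a single int scales the shorter image side to 800 px while keeping
# the aspect ratio, so transform(im) yields a normalized float tensor of shape [3, H, W].
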
# for output bounding box post-processing
def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x.unbind(1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)

def rescale_bboxes(out_bbox, size):
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
    return b

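# Worked example (hypothetical numbers): a normalized box (cx, cy, w, h) = (0.5, 0.5, 0.2, 0.4)
# becomes (xmin, ymin, xmax, ymax) = (0.4, 0.3, 0.6, 0.7) after box_cxcywh_to_xyxy, and
# rescale_bboxes on an 800x600 image scales it to (320.0, 180.0, 480.0, 420.0) pixels.
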
# colors for visualization
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
          [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]

def apply_mask(image, mask, color, alpha=0.5):
    """Apply the given mask to the image."""
    for c in range(3):
        image[:, :, c] = np.where(mask == 1,
                                  image[:, :, c] * (1 - alpha) + alpha * color[c] * 255,
                                  image[:, :, c])
    return image

def plot_results(pil_img, scores, boxes, labels, masks=None):
    plt.figure(figsize=(16, 10))
    np_image = np.array(pil_img)
    ax = plt.gca()
    colors = COLORS * 100
    if masks is None:
        masks = [None for _ in range(len(scores))]
    assert len(scores) == len(boxes) == len(labels) == len(masks)
    for s, (xmin, ymin, xmax, ymax), l, mask, c in zip(scores, boxes.tolist(), labels, masks, colors):
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color=c, linewidth=3))
        text = f'{l}: {s:0.2f}'
        ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor='white', alpha=0.8))

        if mask is None:
            continue
        np_image = apply_mask(np_image, mask, c)

        padded_mask = np.zeros((mask.shape[0] + 2, mask.shape[1] + 2), dtype=np.uint8)
        padded_mask[1:-1, 1:-1] = mask
        contours = find_contours(padded_mask, 0.5)
        for verts in contours:
            # Subtract the padding and flip (y, x) to (x, y)
            verts = np.fliplr(verts) - 1
            p = Polygon(verts, facecolor="none", edgecolor=c)
            ax.add_patch(p)

    plt.imshow(np_image)
    plt.axis('off')
    # save to disk and return the path; the Gradio output below uses type="file"
    plt.savefig('foo.png', bbox_inches='tight')
    plt.close()  # free the figure so repeated calls don't accumulate open figures
    return 'foo.png'


# ad-hoc visualization helper; not called by the Gradio demo below
def add_res(results, ax, color='green'):
    bboxes = results['boxes']
    labels = results['labels']
    scores = results['scores']

    colors = ['purple', 'yellow', 'red', 'green', 'orange', 'pink']

    for i, (b, ll, ss) in enumerate(zip(bboxes, labels, scores)):
        ax.add_patch(plt.Rectangle((b[0], b[1]), b[2] - b[0], b[3] - b[1], fill=False, color=colors[i], linewidth=3))
        # integer labels require a global CLASSES list, which this file does not define
        cls_name = ll if isinstance(ll, str) else CLASSES[ll]
        text = f'{cls_name}: {ss:.2f}'
        print(text)
        ax.text(b[0], b[1], text, fontsize=15, bbox=dict(facecolor='white', alpha=0.8))

model, postprocessor = torch.hub.load('ashkamath/mdetr:main', 'mdetr_efficientnetB5', pretrained=True, return_postprocessor=True)
model = model.cpu()
model.eval()
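# Note: torch.hub.load above downloads the MDETR EfficientNet-B5 checkpoint on first run,
# and the model stays on CPU, so startup and per-image inference can be slow.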


def plot_inference(im, caption):
    # mean-std normalize the input image (batch-size: 1)
    img = transform(im).unsqueeze(0).cpu()

    # propagate through the model in two passes: the first encodes the image and the
    # caption and caches the encoder output, the second decodes predictions from it
    memory_cache = model(img, [caption], encode_and_save=True)
    outputs = model(img, [caption], encode_and_save=False, memory_cache=memory_cache)

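    # Each query's pred_logits is a distribution over token positions in the caption,
    # with the last bin reserved for "no object"; 1 - P(no object) is therefore the
    # confidence that the query actually grounds a phrase.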
    # keep only predictions with 0.7+ confidence
    probas = 1 - outputs['pred_logits'].softmax(-1)[0, :, -1].cpu()
    keep = (probas > 0.7).cpu()

    # convert boxes from [0, 1] to image scales
    bboxes_scaled = rescale_bboxes(outputs['pred_boxes'].cpu()[0, keep], im.size)

    # Extract the text spans predicted by each box
    positive_tokens = (outputs["pred_logits"].cpu()[0, keep].softmax(-1) > 0.1).nonzero().tolist()
    predicted_spans = defaultdict(str)
    for tok in positive_tokens:
        item, pos = tok
        if pos < 255:  # skip the final "no object" bin
            span = memory_cache["tokenized"].token_to_chars(0, pos)
            predicted_spans[item] += " " + caption[span.start:span.end]

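    # For example (hypothetical values): with the caption "a cat on a red chair", a kept
    # box whose token distribution peaks on "cat" would end up labeled " cat".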
    labels = [predicted_spans[k] for k in sorted(predicted_spans.keys())]
    return plot_results(im, probas[keep], bboxes_scaled, labels)


title = "MDETR"
description = "Gradio demo for MDETR: Modulated Detection for End-to-End Multi-Modal Understanding. Upload an image and enter a text query; the model detects the objects the query mentions and labels each box with the matching text span. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2104.12763'>MDETR: Modulated Detection for End-to-End Multi-Modal Understanding</a> | <a href='https://github.com/ashkamath/mdetr'>Github Repo</a></p>"

gr.Interface(
    plot_inference,
    [gr.inputs.Image(type="pil", label="Input"), gr.inputs.Textbox(label="Input text")],
    gr.outputs.Image(type="file", label="Output"),
    title=title,
    description=description,
    article=article,
).launch(debug=True)
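
For reference, a minimal sketch of driving the model without the web UI, assuming this file is saved as app.py with the final gr.Interface(...).launch(...) call guarded by if __name__ == "__main__": (the guard, the run_local.py filename, the sample.jpg image, and the caption are all hypothetical):

# run_local.py — hypothetical smoke test for plot_inference
from PIL import Image
from app import plot_inference  # assumes app.py guards its launch() call

im = Image.open("sample.jpg").convert("RGB")  # hypothetical local image
out_path = plot_inference(im, "a cat on a red chair")
print("annotated image written to", out_path)  # plot_results saves to foo.png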