nielsr (HF staff) committed
Commit 67a6282 · 1 parent: 4ebe060

Create app.py

Files changed (1)
  app.py +162 -0
app.py ADDED
@@ -0,0 +1,162 @@
import gradio as gr
from transformers import PerceiverForOpticalFlow
import torch
import torch.nn.functional as F
import numpy as np
import requests
from PIL import Image
import matplotlib.pyplot as plt
import itertools
import math
import cv2

# Load the pre-trained optical flow Perceiver; TRAIN_SIZE is the (height, width)
# patch size the model was trained on.
model = PerceiverForOpticalFlow.from_pretrained("deepmind/optical-flow-perceiver")
TRAIN_SIZE = model.config.train_size
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

def normalize(im):
    # Scale pixel values from [0, 255] to [-1, 1].
    return im / 255.0 * 2 - 1

# Source: https://discuss.pytorch.org/t/tf-extract-image-patches-in-pytorch/43837/9
def extract_image_patches(x, kernel, stride=1, dilation=1):
    # Do TF 'SAME' padding.
    b, c, h, w = x.shape
    h2 = math.ceil(h / stride)
    w2 = math.ceil(w / stride)
    pad_row = (h2 - 1) * stride + (kernel - 1) * dilation + 1 - h
    pad_col = (w2 - 1) * stride + (kernel - 1) * dilation + 1 - w
    # Note: F.pad pads the last (width) dimension first; with stride=1 and dilation=1,
    # pad_row == pad_col == kernel - 1, so the argument ordering is harmless here.
    x = F.pad(x, (pad_row // 2, pad_row - pad_row // 2, pad_col // 2, pad_col - pad_col // 2))

    # Extract kernel x kernel patches around every pixel and flatten them into channels.
    patches = x.unfold(2, kernel, stride).unfold(3, kernel, stride)
    patches = patches.permute(0, 4, 5, 1, 2, 3).contiguous()

    return patches.view(b, -1, patches.shape[-2], patches.shape[-1])

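# Shape sketch (illustrative, assuming 3-channel frames and kernel=3): each pixel is
# replaced by its flattened 3x3 neighbourhood, so an input of shape (batch * 2, 3, H, W)
# comes out as (batch * 2, 27, H, W), which appears to match the per-pixel features the
# flow Perceiver expects.
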
def compute_optical_flow(model, img1, img2, grid_indices, FLOW_SCALE_FACTOR=20):
    """Compute the optical flow between two images.

    To compute the flow between images of arbitrary sizes, we divide the image
    into patches, compute the flow for each patch, and stitch the flows together.

    Args:
        model: PyTorch Perceiver model.
        img1: first image.
        img2: second image.
        grid_indices: indices of the upper-left corner of each patch.
    """
    img1 = torch.tensor(np.moveaxis(img1, -1, 0))
    img2 = torch.tensor(np.moveaxis(img2, -1, 0))
    imgs = torch.stack([img1, img2], dim=0)[None]
    height = imgs.shape[-2]
    width = imgs.shape[-1]

    patch_size = model.config.train_size

    if height < patch_size[0]:
        raise ValueError(
            f"Height of image (shape: {imgs.shape}) must be at least {patch_size[0]}. "
            "Please pad or resize your image to the minimum dimension."
        )
    if width < patch_size[1]:
        raise ValueError(
            f"Width of image (shape: {imgs.shape}) must be at least {patch_size[1]}. "
            "Please pad or resize your image to the minimum dimension."
        )

    flows = 0
    flow_count = 0

    for y, x in grid_indices:
        inp_piece = imgs[..., y : y + patch_size[0], x : x + patch_size[1]]

        # Turn every pixel of both frames into its flattened 3x3 neighbourhood.
        batch_size, _, C, H, W = inp_piece.shape
        patches = extract_image_patches(inp_piece.view(batch_size * 2, C, H, W), kernel=3)
        _, C, H, W = patches.shape
        patches = patches.view(batch_size, -1, C, H, W).float().to(model.device)

        # Actual forward pass.
        with torch.no_grad():
            output = model(inputs=patches).logits * FLOW_SCALE_FACTOR

        # The code below could also be implemented in PyTorch.
        flow_piece = output.cpu().detach().numpy()

        # Weights taper from the patch centre towards its borders, so overlapping
        # patches are blended with a centre-weighted average.
        weights_x, weights_y = np.meshgrid(
            np.arange(patch_size[1]), np.arange(patch_size[0]))

        weights_x = np.minimum(weights_x + 1, patch_size[1] - weights_x)
        weights_y = np.minimum(weights_y + 1, patch_size[0] - weights_y)
        weights = np.minimum(weights_x, weights_y)[np.newaxis, :, :, np.newaxis]
        padding = [(0, 0), (y, height - y - patch_size[0]),
                   (x, width - x - patch_size[1]), (0, 0)]
        flows += np.pad(flow_piece * weights, padding)
        flow_count += np.pad(weights, padding)

        # Delete activations to avoid OOM.
        del output

    flows /= flow_count
    return flows

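# Worked example of the stitching weights above (assuming a patch width of 496):
# weights_x = min(x + 1, 496 - x) runs 1, 2, ..., 248, 248, ..., 2, 1 across the patch,
# a triangular profile that peaks at the centre, so pixels covered by several overlapping
# patches are dominated by the patch whose centre lies closest to them.
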
def compute_grid_indices(image_shape, patch_size=TRAIN_SIZE, min_overlap=20):
    if min_overlap >= TRAIN_SIZE[0] or min_overlap >= TRAIN_SIZE[1]:
        raise ValueError(
            f"Overlap should be less than size of patch (got {min_overlap} "
            f"for patch size {patch_size}).")
    ys = list(range(0, image_shape[0], TRAIN_SIZE[0] - min_overlap))
    xs = list(range(0, image_shape[1], TRAIN_SIZE[1] - min_overlap))
    # Make sure the final patch is flush with the image boundary.
    ys[-1] = image_shape[0] - patch_size[0]
    xs[-1] = image_shape[1] - patch_size[1]
    return itertools.product(ys, xs)

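# Illustrative example (assuming TRAIN_SIZE == (368, 496) and the 436 x 1024 Sintel frames
# used below): ys = [0, 68] and xs = [0, 476, 528], i.e. six overlapping patches whose
# flows are blended by compute_optical_flow above.
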
def return_flow(flow):
    flow = np.array(flow)
    # Use the Hue, Saturation, Value colour model.
    hsv = np.zeros((flow.shape[0], flow.shape[1], 3), dtype=np.uint8)
    hsv[..., 2] = 255

    mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1])
    hsv[..., 0] = ang / np.pi / 2 * 180
    hsv[..., 1] = np.clip(mag * 255 / 24, 0, 255)
    bgr = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
    return Image.fromarray(bgr)

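# In the visualisation above, hue encodes the flow direction, saturation encodes the flow
# magnitude (saturating at roughly 24 pixels of displacement), and value is fixed at 255.
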
# Download the example images (two Sintel frames) and save them locally so they can be
# offered as Gradio examples below.
urls = ["https://storage.googleapis.com/perceiver_io/sintel_frame1.png",
        "https://storage.googleapis.com/perceiver_io/sintel_frame2.png"]

for idx, url in enumerate(urls):
    image = Image.open(requests.get(url, stream=True).raw)
    image.save(f"image_{idx}.png")

def process_images(image1, image2):
    im1 = np.array(image1)
    im2 = np.array(image2)

    # Divide the images into patches, compute the flow between corresponding patches
    # of both images, and stitch the flows together.
    grid_indices = compute_grid_indices(im1.shape)
    output = compute_optical_flow(model, normalize(im1), normalize(im2), grid_indices)

    # Return the flow visualisation as a PIL Image.
    predicted_flow = return_flow(output[0])
    return predicted_flow

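# Minimal offline sketch (not part of the Gradio app; assumes the example frames
# downloaded above exist on disk):
#
#     frame1 = Image.open("image_0.png").convert("RGB")
#     frame2 = Image.open("image_1.png").convert("RGB")
#     process_images(frame1, frame2).save("predicted_flow.png")
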
title = "Interactive demo: Perceiver for optical flow"
description = "Demo for predicting optical flow with Perceiver IO. To use it, simply upload two images (e.g. subsequent video frames) or use the example images below and click 'submit' to let the model predict how the pixels move between the two frames. Results will show up in a few seconds."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2107.14795'>Perceiver IO: A General Architecture for Structured Inputs & Outputs</a> | <a href='https://deepmind.com/blog/article/building-architectures-that-can-handle-the-worlds-data/'>Official blog</a></p>"
examples = [[f"image_{idx}.png" for idx in range(len(urls))]]

iface = gr.Interface(fn=process_images,
                     inputs=[gr.inputs.Image(type="pil"), gr.inputs.Image(type="pil")],
                     outputs=gr.outputs.Image(type="pil"),
                     title=title,
                     description=description,
                     article=article,
                     examples=examples)
iface.launch(debug=True)