Create app.py
app.py (ADDED)
import gradio as gr
from transformers import PerceiverForOpticalFlow
import torch
import torch.nn.functional as F
import numpy as np
import requests
from PIL import Image
import itertools
import math
import cv2

model = PerceiverForOpticalFlow.from_pretrained("deepmind/optical-flow-perceiver")
TRAIN_SIZE = model.config.train_size
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

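# TRAIN_SIZE is the (height, width) the model was trained on; for this
# checkpoint it should be (368, 496). Images are processed below in patches
# of exactly this size.
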
def normalize(im):
    # map uint8 pixel values from [0, 255] to [-1, 1]
    return im / 255.0 * 2 - 1

# source: https://discuss.pytorch.org/t/tf-extract-image-patches-in-pytorch/43837/9
def extract_image_patches(x, kernel, stride=1, dilation=1):
    # Do TF 'SAME' padding
    b, c, h, w = x.shape
    h2 = math.ceil(h / stride)
    w2 = math.ceil(w / stride)
    pad_row = (h2 - 1) * stride + (kernel - 1) * dilation + 1 - h
    pad_col = (w2 - 1) * stride + (kernel - 1) * dilation + 1 - w
    x = F.pad(x, (pad_row // 2, pad_row - pad_row // 2, pad_col // 2, pad_col - pad_col // 2))

    # Extract overlapping kernel x kernel patches around every pixel
    patches = x.unfold(2, kernel, stride).unfold(3, kernel, stride)
    patches = patches.permute(0, 4, 5, 1, 2, 3).contiguous()

    return patches.view(b, -1, patches.shape[-2], patches.shape[-1])

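# The Perceiver flow model takes, for every pixel, the 3x3 neighbourhood of
# RGB values around it as input features, which is why extract_image_patches
# is called with kernel=3 below.
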
def compute_optical_flow(model, img1, img2, grid_indices, FLOW_SCALE_FACTOR=20):
    """Compute optical flow between two images.

    To compute the flow between images of arbitrary sizes, we divide the image
    into patches, compute the flow for each patch, and stitch the flows together.

    Args:
      model: PyTorch Perceiver model.
      img1: first image.
      img2: second image.
      grid_indices: indices of the upper-left corner of each patch.
    """
    img1 = torch.tensor(np.moveaxis(img1, -1, 0))
    img2 = torch.tensor(np.moveaxis(img2, -1, 0))
    imgs = torch.stack([img1, img2], dim=0)[None]
    height = imgs.shape[-2]
    width = imgs.shape[-1]

    patch_size = model.config.train_size

    if height < patch_size[0]:
        raise ValueError(
            f"Height of image (shape: {imgs.shape}) must be at least {patch_size[0]}. "
            "Please pad or resize your image to the minimum dimension."
        )
    if width < patch_size[1]:
        raise ValueError(
            f"Width of image (shape: {imgs.shape}) must be at least {patch_size[1]}. "
            "Please pad or resize your image to the minimum dimension."
        )

    flows = 0
    flow_count = 0

    for y, x in grid_indices:
        inp_piece = imgs[..., y : y + patch_size[0], x : x + patch_size[1]]

        batch_size, _, C, H, W = inp_piece.shape
        patches = extract_image_patches(inp_piece.view(batch_size * 2, C, H, W), kernel=3)
        _, C, H, W = patches.shape
        patches = patches.view(batch_size, -1, C, H, W).float().to(model.device)

        # actual forward pass
        with torch.no_grad():
            output = model(inputs=patches).logits * FLOW_SCALE_FACTOR

        # the code below could also be implemented in PyTorch
        flow_piece = output.cpu().detach().numpy()

        # weight each pixel by its distance to the nearest patch border, so
        # that overlapping patches blend smoothly when stitched together
        weights_x, weights_y = np.meshgrid(
            np.arange(patch_size[1]), np.arange(patch_size[0]))

        weights_x = np.minimum(weights_x + 1, patch_size[1] - weights_x)
        weights_y = np.minimum(weights_y + 1, patch_size[0] - weights_y)
        weights = np.minimum(weights_x, weights_y)[np.newaxis, :, :, np.newaxis]
        padding = [(0, 0), (y, height - y - patch_size[0]),
                   (x, width - x - patch_size[1]), (0, 0)]
        flows += np.pad(flow_piece * weights, padding)
        flow_count += np.pad(weights, padding)

        # delete activations to avoid OOM
        del output

    # normalize the accumulated flow by the accumulated weights
    flows /= flow_count
    return flows

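# The stitching weights in compute_optical_flow form a pyramid over each
# patch: e.g. for a patch width of 4, weights_x is [1, 2, 2, 1], so border
# pixels contribute less and overlapping patches blend linearly.
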
def compute_grid_indices(image_shape, patch_size=TRAIN_SIZE, min_overlap=20):
    if min_overlap >= patch_size[0] or min_overlap >= patch_size[1]:
        raise ValueError(
            f"Overlap should be less than size of patch (got {min_overlap} "
            f"for patch size {patch_size}).")
    ys = list(range(0, image_shape[0], patch_size[0] - min_overlap))
    xs = list(range(0, image_shape[1], patch_size[1] - min_overlap))
    # Make sure the final patch is flush with the image boundary
    ys[-1] = image_shape[0] - patch_size[0]
    xs[-1] = image_shape[1] - patch_size[1]
    return itertools.product(ys, xs)

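# A minimal usage sketch of the two functions above, assuming `frame1` and
# `frame2` are H x W x 3 uint8 arrays at least TRAIN_SIZE in each dimension:
#
#   grid_indices = compute_grid_indices(frame1.shape)
#   flow = compute_optical_flow(model, normalize(frame1), normalize(frame2), grid_indices)
#
# `flow` then has shape (1, H, W, 2); the two channels are used as the x and y
# flow components below.
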
def return_flow(flow):
    flow = np.array(flow)
    # Use the Hue, Saturation, Value colour model:
    # hue encodes flow direction, saturation encodes flow magnitude
    hsv = np.zeros((flow.shape[0], flow.shape[1], 3), dtype=np.uint8)
    hsv[..., 2] = 255

    mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1])
    hsv[..., 0] = ang / np.pi / 2 * 180
    hsv[..., 1] = np.clip(mag * 255 / 24, 0, 255)
    # convert to RGB, since PIL expects RGB channel order
    rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
    return Image.fromarray(rgb)

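# For example, a flow vector pointing along positive x (zero y) has angle 0
# and therefore hue 0 (red); magnitudes of roughly 24 pixels or more saturate
# to full colour.
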
# load example images (two consecutive frames from the Sintel dataset)
urls = ["https://storage.googleapis.com/perceiver_io/sintel_frame1.png",
        "https://storage.googleapis.com/perceiver_io/sintel_frame2.png"]

for idx, url in enumerate(urls):
    image = Image.open(requests.get(url, stream=True).raw)
    image.save(f"image_{idx}.png")

def process_images(image1, image2):
    im1 = np.array(image1)
    im2 = np.array(image2)

    # Divide images into patches, compute flow between corresponding patches
    # of both images, and stitch the flows together
    grid_indices = compute_grid_indices(im1.shape)
    output = compute_optical_flow(model, normalize(im1), normalize(im2), grid_indices)

    # return as PIL Image
    predicted_flow = return_flow(output[0])
    return predicted_flow

title = "Interactive demo: Perceiver for optical flow"
description = "Demo for predicting optical flow with Perceiver IO. To use it, simply upload 2 images (e.g. subsequent frames) or use the example images below and click 'submit' to let the model predict the flow of the pixels. Results will show up in a few seconds."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2107.14795'>Perceiver IO: A General Architecture for Structured Inputs & Outputs</a> | <a href='https://deepmind.com/blog/article/building-architectures-that-can-handle-the-worlds-data/'>Official blog</a></p>"
examples = [[f"image_{idx}.png" for idx in range(len(urls))]]

# gr.Image replaces the deprecated gr.inputs/gr.outputs namespaces, which were
# removed in Gradio 4
iface = gr.Interface(fn=process_images,
                     inputs=[gr.Image(type="pil"), gr.Image(type="pil")],
                     outputs=gr.Image(type="pil"),
                     title=title,
                     description=description,
                     article=article,
                     examples=examples)
iface.launch(debug=True)