import os

# Pin the Gradio version before anything imports it, so the reinstalled
# package is the one this session actually loads.
os.system("pip uninstall -y gradio")
os.system("pip install gradio==4.31.0")

import spaces
import cv2
import numpy as np
import gradio as gr
import cwm.utils as utils
|
# Drawing style for the motion arrows and point markers rendered with OpenCV.
arrow_color = (0, 255, 0)
dot_color = (0, 255, 0)
dot_color_fixed = (255, 0, 0)  # red marks points kept fixed
thickness = 3
tip_length = 0.3
dot_radius = 7
dot_thickness = -1  # negative thickness draws filled circles

from PIL import Image
import torch
import json

from cwm.model.model_factory import model_factory

from timm.data.constants import (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Two-frame counterfactual world model (ViT-B, 8x8 patches) used to predict
# the edited scene.
model = model_factory.load_model('vitb_8x8patch_2frames_encoder_mask_token')

model.requires_grad_(False)
model.eval()

model = model.to(device)
|

import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrowPatch
from torchvision import transforms

|
def draw_arrows_matplotlib(img, selected_points, zero_length):
    """
    Draw arrows on the image using matplotlib, for higher-quality arrows and dots.
    """
    fig, ax = plt.subplots()
    ax.imshow(img)

    for i in range(0, len(selected_points), 2):
        start_point = selected_points[i]
        end_point = selected_points[i + 1]

        if start_point == end_point or zero_length:
            # A zero-length selection marks a point that should stay fixed.
            ax.scatter(start_point[0], start_point[1], color='red', s=100)
        else:
            arrow = FancyArrowPatch((start_point[0], start_point[1]),
                                    (end_point[0], end_point[1]),
                                    color='green', linewidth=2,
                                    arrowstyle='->', mutation_scale=15)
            ax.add_patch(arrow)
            ax.scatter(start_point[0], start_point[1], color='green', s=100)
            ax.scatter(end_point[0], end_point[1], color='green', s=100)

    # Render the figure and return it as an RGB numpy array.
    fig.canvas.draw()
    img_array = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    img_array = img_array.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close(fig)
    return img_array
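
# Note: draw_arrows_matplotlib is an alternative renderer; the UI below draws
# its annotations directly with OpenCV instead.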
|
|
def get_c(x, points):
    # Normalize with ImageNet statistics, move inputs to the model's device,
    # and run the intervention without tracking gradients.
    x = utils.imagenet_normalize(x).to(device)
    with torch.no_grad():
        counterfactual = model.get_counterfactual(x, points.to(device))
    return counterfactual
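
# Shape conventions (inferred from run_model_on_points below): x is a
# (B, C, T, H, W) two-frame clip with values in [0, 1], and points is an
# (N, 4) tensor of (y_start, x_start, y_end, x_end) motions at 256x256.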
|
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown('''# Scene editing interventions with Counterfactual World Models!
        ''')

    with gr.Tab(label='Image'):
        with gr.Row():
            with gr.Column():
                # Hidden state: the clean resized image and the original
                # full-resolution upload.
                original_image = gr.State(value=None)
                original_image_high_res = gr.State(value=None)
                input_image = gr.Image(type="numpy", label="Upload Image")

                # Clicked points, stored as a flat list of [start, end] pairs.
                selected_points = gr.State([])
                zero_length_toggle = gr.Checkbox(label="Select patches to be kept fixed", value=False)
                with gr.Row():
                    gr.Markdown('1. **Click on the image** to specify patch motion by selecting a start and end point. \n 2. After selecting the points to move, enable the **"Select patches to be kept fixed"** checkbox to choose a few points to keep fixed. \n 3. **Click "Run Model"** to visualize the result of the edit.')
                undo_button = gr.Button('Undo last action')
                clear_button = gr.Button('Clear All')

                run_model_button = gr.Button('Run Model')

    with gr.Tab(label='Intervention'):
        output_image = gr.Image(type='numpy')

    def resize_to_square(img, size=512):
        print("Resizing image to square")
        img = Image.fromarray(img)
        transform = transforms.Compose([
            transforms.Resize((size, size)),
        ])
        img = transform(img)
        return np.array(img)
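
    # Note: Resize((size, size)) stretches the input to a square, so the
    # aspect ratio of non-square uploads is not preserved.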
|
|
    def load_img(evt: gr.SelectData):
        img_path = evt.value['image']['path']
        img = np.array(Image.open(img_path))

        # Look up preset motion annotations for this example image.
        with open('./assets/intervention_test_images/annot.json', 'r') as f:
            points_json = json.load(f)

        resized_img = resize_to_square(img)

        if os.path.basename(img_path) not in points_json:
            return resized_img, resized_img, img, []

        points_json = points_json[os.path.basename(img_path)]

        temp = resized_img.copy()

        # Draw the preset annotations: an arrow for each (start, end) motion
        # pair, and a red dot where start == end (patches kept fixed).
        for i in range(0, len(points_json) - 1, 2):
            start_point = points_json[i]
            end_point = points_json[i + 1]
            if start_point == end_point:
                color = dot_color_fixed
            else:
                cv2.arrowedLine(temp, start_point, end_point, arrow_color, thickness,
                                tipLength=tip_length, line_type=cv2.LINE_AA)
                color = arrow_color

            cv2.circle(temp, start_point, dot_radius, color, dot_thickness, lineType=cv2.LINE_AA)
            cv2.circle(temp, end_point, dot_radius, color, dot_thickness, lineType=cv2.LINE_AA)

        # A dangling start point (odd count) is drawn without an arrow.
        if len(points_json) % 2 == 1:
            cv2.circle(temp, points_json[-1], dot_radius, dot_color, dot_thickness, lineType=cv2.LINE_AA)

        return temp, resized_img, img, points_json

    def store_img(img):
        resized_img = resize_to_square(img)
        print(f"Image uploaded with shape: {resized_img.shape}")
        return resized_img, resized_img, img, []
|

    with gr.Row():
        with gr.Column():
            gallery = gr.Gallery(
                ["./assets/ducks.jpg", "./assets/robot_arm.jpg", "./assets/bread.jpg",
                 "./assets/bird.jpg", "./assets/desk_1.jpg", "./assets/glasses.jpg",
                 "./assets/watering_pot.jpg"],
                columns=5, allow_preview=False, label="Select an example image to test")
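
    # Selecting a thumbnail loads the image along with any preset edits via
    # load_img above.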
|
    # The example gallery and manual uploads feed the same outputs: the
    # displayed image, the clean resized copy, the full-resolution original,
    # and the list of selected points.
    gallery.select(load_img, outputs=[input_image, original_image, original_image_high_res, selected_points])
    input_image.upload(store_img, [input_image], [input_image, original_image, original_image_high_res, selected_points])
|
    def get_point(img, sel_pix, zero_length, evt: gr.SelectData):
        sel_pix.append(evt.index)

        if zero_length:
            # Fixed-point mode: store a zero-length pair (start == end) and
            # mark it with a red dot.
            point = sel_pix[-1]
            cv2.circle(img, point, dot_radius, dot_color_fixed, dot_thickness, lineType=cv2.LINE_AA)
            sel_pix.append(evt.index)
        else:
            # Motion mode: odd clicks start an arrow, even clicks finish one.
            if len(sel_pix) % 2 == 1:
                start_point = sel_pix[-1]
                cv2.circle(img, start_point, dot_radius, dot_color, dot_thickness, lineType=cv2.LINE_AA)

            if len(sel_pix) % 2 == 0:
                start_point = sel_pix[-2]
                end_point = sel_pix[-1]
                cv2.arrowedLine(img, start_point, end_point, arrow_color, thickness,
                                tipLength=tip_length, line_type=cv2.LINE_AA)
                cv2.circle(img, end_point, dot_radius, dot_color, dot_thickness, lineType=cv2.LINE_AA)

        return img if isinstance(img, np.ndarray) else np.array(img)
|
    input_image.select(get_point, [input_image, selected_points, zero_length_toggle], [input_image])
|
    def undo_arrows(orig_img, sel_pix, zero_length):
        temp = orig_img.copy()

        # Drop the last pair (or a dangling start point), then redraw the
        # remaining annotations on a clean copy of the original image.
        if len(sel_pix) % 2 == 1:
            sel_pix.pop()
        elif len(sel_pix) >= 2:
            sel_pix.pop()
            sel_pix.pop()

        for i in range(0, len(sel_pix) - 1, 2):
            start_point = sel_pix[i]
            end_point = sel_pix[i + 1]
            if start_point == end_point:
                color = dot_color_fixed
            else:
                cv2.arrowedLine(temp, start_point, end_point, arrow_color, thickness,
                                tipLength=tip_length, line_type=cv2.LINE_AA)
                color = arrow_color

            cv2.circle(temp, start_point, dot_radius, color, dot_thickness, lineType=cv2.LINE_AA)
            cv2.circle(temp, end_point, dot_radius, color, dot_thickness, lineType=cv2.LINE_AA)

        return temp if isinstance(temp, np.ndarray) else np.array(temp)
|
    undo_button.click(undo_arrows, [original_image, selected_points, zero_length_toggle], [input_image])
|

    def clear_all_points(orig_img, sel_pix):
        # Drop all selections and restore the unannotated image.
        sel_pix.clear()
        return orig_img

    clear_button.click(clear_all_points, [original_image, selected_points], [input_image])
|
    def run_model_on_points(points, input_image, original_image):
        H = input_image.shape[0]
        W = input_image.shape[1]
        # Points were clicked on the resized display image; rescale them to
        # the 256x256 resolution the model runs at.
        factor = 256 / H

        # Flatten (start, end) pairs into rows of (x1, y1, x2, y2), then swap
        # to (row, col) order: (y1, x1, y2, x2).
        points = torch.from_numpy(np.array(points).reshape(-1, 4)) * factor
        points = points[:, [1, 0, 3, 2]]

        img = Image.fromarray(original_image)
        transform = transforms.Compose([
            transforms.Resize((256, 256)),
        ])
        img = np.array(transform(img))

        # HWC uint8 -> 1 x C x H x W float tensor in [0, 1].
        img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0
        img = img[None]

        # The two-frame model expects a (B, C, T, H, W) clip; repeat the
        # single frame along the time dimension.
        x = img[:, :, None].expand(-1, -1, 2, -1, -1)

        counterfactual = get_c(x, points)

        counterfactual = counterfactual.squeeze()
        counterfactual = counterfactual.clamp(0, 1).permute(1, 2, 0).detach().cpu().numpy()

        return counterfactual
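
    # The returned array is the model's predicted frame after applying the
    # specified patch motions; it is displayed in the Intervention tab.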
|
    run_model_button.click(run_model_on_points, [selected_points, input_image, original_image_high_res], [output_image])


demo.queue().launch(inbrowser=True, share=True)