import spaces |
import cv2 |
import numpy as np |
import gradio as gr |
import cwm.utils as utils |
import os |
# Pin Gradio to the version this demo targets (4.31.0).
# NOTE: `gradio` has already been imported above, so a reinstall here can only
# fix the environment for the next start; the running process keeps the module
# it already imported. Reinstall only on a version mismatch: uninstalling the
# package on every launch is slow and leaves the already-imported module out of
# sync with the files on disk.
import importlib.metadata
import subprocess
import sys

if importlib.metadata.version("gradio") != "4.31.0":
    # `sys.executable -m pip` targets this interpreter's environment (a bare
    # `pip` on PATH may belong to another Python). check=False keeps the
    # best-effort, failure-tolerant behaviour of an unchecked shell call.
    subprocess.run([sys.executable, "-m", "pip", "uninstall", "-y", "gradio"], check=False)
    subprocess.run([sys.executable, "-m", "pip", "install", "gradio==4.31.0"], check=False)
# Annotation styling shared by the OpenCV drawing callbacks.
# Colour tuples are written in the channel order of the image arrays being
# drawn on; the matplotlib helper draws the same marks as 'green' / 'red'.
arrow_color = dot_color = (0, 255, 0)   # motion arrows and their endpoint dots
dot_color_fixed = (255, 0, 0)           # dots marking patches kept fixed
thickness, tip_length = 3, 0.3          # arrow line width (px); tip size relative to arrow length
dot_radius, dot_thickness = 7, -1       # dot radius (px); a negative thickness fills the circle in OpenCV
from PIL import Image |
import torch |
from cwm.model.model_factory import model_factory |
from timm.data.constants import (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) |
# Preferred compute device. NOTE(review): `device` is not applied to the model
# below (there is no `.to(device)`), so inference runs wherever `load_model`
# leaves the weights. Moving the model would also require moving the inputs
# passed to it in `get_c`.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Load the pretrained checkpoint once at start-up and freeze it for inference:
# no parameter gradients, and eval-mode layer behaviour.
model = model_factory.load_model('vitb_8x8patch_2frames_encoder_mask_token')
model.requires_grad_(False)
model.eval()
import matplotlib.pyplot as plt |
from matplotlib.patches import FancyArrowPatch |
from PIL import Image |
import numpy as np |
from torchvision import transforms |
def draw_arrows_matplotlib(img, selected_points, zero_length):
    """
    Draw arrows on the image using matplotlib for better quality arrows and dots.

    Args:
        img: image accepted by ``Axes.imshow``.
        selected_points: flat sequence of ``(x, y)`` points consumed in
            (start, end) pairs.
        zero_length: if true, every pair is drawn as a single red dot at its
            start point instead of an arrow.

    A pair whose start equals its end is also drawn as a red dot. Any other
    pair becomes a green arrow with green dots at both ends. A trailing
    unpaired start point is drawn as a lone dot (red if ``zero_length``,
    otherwise green) rather than raising ``IndexError``.

    Returns:
        ``uint8`` RGB array of shape ``(H, W, 3)`` holding the rendered figure.
        Its size is the figure canvas size, not the size of ``img``.
    """
    fig, ax = plt.subplots()
    try:
        ax.imshow(img)
        n_paired = len(selected_points) - len(selected_points) % 2
        for i in range(0, n_paired, 2):
            start_point = selected_points[i]
            end_point = selected_points[i + 1]
            if start_point == end_point or zero_length:
                ax.scatter(start_point[0], start_point[1], color='red', s=100)
            else:
                arrow = FancyArrowPatch((start_point[0], start_point[1]), (end_point[0], end_point[1]),
                                        color='green', linewidth=2, arrowstyle='->', mutation_scale=15)
                ax.add_patch(arrow)
                ax.scatter(start_point[0], start_point[1], color='green', s=100)
                ax.scatter(end_point[0], end_point[1], color='green', s=100)
        if n_paired < len(selected_points):
            # A start point whose end has not been selected yet.
            lone = selected_points[-1]
            ax.scatter(lone[0], lone[1], color='red' if zero_length else 'green', s=100)
        fig.canvas.draw()
        # `tostring_rgb` was deprecated in Matplotlib 3.8 and removed in 3.10.
        # `buffer_rgba()` is the supported accessor, and its shape always
        # matches the rendered buffer. `get_width_height()` can instead report
        # logical rather than physical pixels on HiDPI canvases. Copy so the
        # result does not depend on the figure's buffer after it is closed.
        img_array = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    finally:
        # Release the figure even if drawing fails.
        plt.close(fig)
    return img_array
import os |
def get_c(x, points):
    """Return the model's counterfactual prediction for ``x`` and ``points``.

    ``x`` (e.g. the (1, 3, 2, 256, 256) clip built by ``run_model_on_points``)
    is normalised with ``utils.imagenet_normalize``. It is then passed together
    with ``points`` to ``model.get_counterfactual`` with autograd disabled.
    """
    normalized = utils.imagenet_normalize(x)
    with torch.no_grad():
        return model.get_counterfactual(normalized, points)
# Gradio UI layout. The order and nesting of these context managers define the
# page layout. NOTE(review): confirm the nesting below against the intended
# layout before editing it.
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown('''# Scene editing interventions with Counterfactual World Models!
''')
    with gr.Tab(label='Image'):
        with gr.Row():
            with gr.Column():
                # Per-session state set by the upload / gallery handlers:
                # the un-annotated display-size image (the redraw base for
                # undo / clear) and the image at its uploaded resolution (what
                # run_model_on_points feeds to the model).
                original_image = gr.State(value=None)
                original_image_high_res = gr.State(value=None)
                # Editor image: the clicked points and arrows are drawn on it.
                input_image = gr.Image(type="numpy", label="Upload Image")
                # Flat list of clicked points (evt.index), consumed in
                # (start, end) pairs; a fixed patch is stored as a (p, p) pair.
                selected_points = gr.State([])
                # When checked, clicks mark patches to keep fixed instead of
                # starting / ending a motion arrow.
                zero_length_toggle = gr.Checkbox(label="Select patches to be kept fixed", value=False)
                with gr.Row():
                    gr.Markdown('1. **Click on the image** to specify patch motion by selecting a start and end point. \n 2. After selecting the points to move, enable the **"Select patches to be kept fixed"** checkbox to choose a few points to keep fixed. \n 3. **Click "Run Model"** to visualize the result of the edit.')
                undo_button = gr.Button('Undo last action')
                clear_button = gr.Button('Clear All')
                run_model_button = gr.Button('Run Model')
            with gr.Tab(label='Intervention'):
                # Receives the counterfactual returned by run_model_on_points.
                output_image = gr.Image(type='numpy')
def resize_to_square(img, size=512):
    """Return ``img`` (a numpy image) stretched to a ``size`` x ``size`` array.

    The aspect ratio is not preserved: ``transforms.Resize`` given a
    ``(size, size)`` tuple resizes both axes to exactly ``size``.
    """
    print("Resizing image to square")
    square = transforms.Resize((size, size))(Image.fromarray(img))
    return np.array(square)
def load_img(evt: gr.SelectData):
    """Load the gallery example the user clicked.

    Reads the file path from the select event (``evt.value['image']['path']``).

    Returns, in the order of the gallery ``select`` outputs: the square display
    image twice (shown in the editor and kept as the clean copy), the image at
    its original resolution, and a fresh empty point list.
    """
    img_path = evt.value['image']['path']
    # The context manager closes the file deterministically. Converting to RGB
    # guarantees the 3-channel layout that the drawing code and the model input
    # (permute(2, 0, 1)) expect, even for grayscale, RGBA or palette files.
    with Image.open(img_path) as pil_img:
        img = np.array(pil_img.convert("RGB"))
    resized_img = resize_to_square(img)
    return resized_img, resized_img, img, []
def store_img(img):
    """Handle a freshly uploaded image.

    Returns the square display copy twice (editor image and clean copy), the
    upload at its original resolution, and an empty point list that resets
    any earlier clicks.
    """
    display_img = resize_to_square(img)
    print(f"Image uploaded with shape: {display_img.shape}")
    return display_img, display_img, img, []
with gr.Row():
    with gr.Column():
        gallery = gr.Gallery(
            [
                "./assets/desk_1.jpg",
                "./assets/glasses.jpg",
                "./assets/stick_fig_1.jpg",
                "./assets/watering_pot.jpg",
                "./assets/jordan.jpeg",
            ],
            columns=5,
            allow_preview=False,
            label="Select an example image to test",
        )
# Both image sources refresh the same four components: the editor image, the
# clean display copy, the full-resolution original and the point list.
gallery.select(
    fn=load_img,
    outputs=[input_image, original_image, original_image_high_res, selected_points],
)
input_image.upload(
    fn=store_img,
    inputs=[input_image],
    outputs=[input_image, original_image, original_image_high_res, selected_points],
)
def get_point(img, sel_pix, zero_length, evt: gr.SelectData):
    """Record a click on the editor image and draw its annotation in place.

    ``sel_pix`` is the session's flat point list, consumed in (start, end) pairs:

    * with ``zero_length`` checked, the click is stored twice (start == end) to
      mark a patch kept fixed, and is drawn as a ``dot_color_fixed`` dot;
    * otherwise a click after an even number of points starts a motion (dot),
      and the next click ends it (arrow plus end dot).

    A fixed-patch click while a motion start is still waiting for its end is
    ignored with a warning, because it would misalign every later pair.

    Returns the annotated image as a numpy array.
    """
    if zero_length and len(sel_pix) % 2 == 1:
        # Appending (p, p) after an unpaired start would pair that start with p
        # as an arrow and leave the second p dangling.
        gr.Warning("Finish the current arrow (uncheck the fixed-patch box and click its end point) before selecting fixed patches.")
        return img if isinstance(img, np.ndarray) else np.array(img)
    point = evt.index
    if zero_length:
        sel_pix.extend([point, point])
        cv2.circle(img, point, dot_radius, dot_color_fixed, dot_thickness, lineType=cv2.LINE_AA)
    elif len(sel_pix) % 2 == 0:
        # Even count before this click: it starts a new motion.
        sel_pix.append(point)
        cv2.circle(img, point, dot_radius, dot_color, dot_thickness, lineType=cv2.LINE_AA)
    else:
        # Odd count: this click completes the pending motion.
        start_point = sel_pix[-1]
        sel_pix.append(point)
        cv2.arrowedLine(img, start_point, point, arrow_color, thickness, tipLength=tip_length, line_type=cv2.LINE_AA)
        cv2.circle(img, point, dot_radius, dot_color, dot_thickness, lineType=cv2.LINE_AA)
    return img if isinstance(img, np.ndarray) else np.array(img)


input_image.select(get_point, [input_image, selected_points, zero_length_toggle], [input_image])
def undo_arrows(orig_img, sel_pix, zero_length):
    """Undo the most recent annotation and redraw the remaining ones.

    If the last click was an unpaired motion start (odd number of points), only
    that point is removed. Otherwise the last (start, end) pair is removed,
    whether it is a motion arrow or a fixed patch stored as (p, p). ``sel_pix``
    is modified in place.

    The remaining annotations are redrawn on a copy of ``orig_img`` with the
    same colours, drawing order and anti-aliasing as ``get_point``.
    ``zero_length`` is accepted for the event wiring but not used.

    Returns the redrawn image, or ``None`` when no image has been loaded.
    """
    # Remove exactly one user action. Always popping two entries would split a
    # pair whenever a start point is pending.
    if len(sel_pix) % 2 == 1:
        sel_pix.pop()
    elif sel_pix:
        del sel_pix[-2:]

    if orig_img is None:
        return None

    temp = orig_img.copy()
    # Iterate over complete pairs only, so an unpaired tail can never cause an
    # IndexError.
    for i in range(0, len(sel_pix) - 1, 2):
        start_point, end_point = sel_pix[i], sel_pix[i + 1]
        if start_point == end_point:
            cv2.circle(temp, start_point, dot_radius, dot_color_fixed, dot_thickness, lineType=cv2.LINE_AA)
        else:
            cv2.circle(temp, start_point, dot_radius, dot_color, dot_thickness, lineType=cv2.LINE_AA)
            cv2.arrowedLine(temp, start_point, end_point, arrow_color, thickness, tipLength=tip_length, line_type=cv2.LINE_AA)
            cv2.circle(temp, end_point, dot_radius, dot_color, dot_thickness, lineType=cv2.LINE_AA)
    if len(sel_pix) % 2 == 1:
        # Defensive: a pending start point (not expected after the undo above).
        cv2.circle(temp, sel_pix[-1], dot_radius, dot_color, dot_thickness, lineType=cv2.LINE_AA)
    return temp if isinstance(temp, np.ndarray) else np.array(temp)
undo_button.click(
    fn=undo_arrows,
    inputs=[original_image, selected_points, zero_length_toggle],
    outputs=[input_image],
)
def clear_all_points(orig_img, sel_pix):
    """Forget every selected point and hand back the clean image for display.

    ``sel_pix`` is emptied in place, so the caller's list object is reused
    rather than replaced.
    """
    del sel_pix[:]
    return orig_img
clear_button.click(
    fn=clear_all_points,
    inputs=[original_image, selected_points],
    outputs=[input_image],
)
def run_model_on_points(points, input_image, original_image):
    """Apply the clicked interventions to the image with the CWM model.

    Args:
        points: flat list of clicked ``[x, y]`` points in ``input_image`` pixel
            coordinates, consumed in (start, end) pairs. A fixed patch is a pair
            with start == end. A trailing unpaired start point is ignored.
        input_image: the annotated editor image. Only its height and width are
            used, to rescale the points.
        original_image: the un-annotated image at its uploaded resolution. This
            is what is fed to the model.

    Returns:
        The counterfactual as a float numpy array in [0, 1], laid out (H, W, C).
        This assumes ``get_c`` returns a single image once squeezed.

    Raises:
        gr.Error: if no image has been loaded yet.
    """
    if input_image is None or original_image is None:
        raise gr.Error("Please upload or select an image first.")
    model_size = 256  # square resolution of the model input
    H, W = input_image.shape[:2]

    # Keep complete (start, end) pairs only; an odd count would make
    # reshape(-1, 4) fail.
    complete = points[: len(points) - len(points) % 2]
    pairs = torch.from_numpy(np.array(complete).reshape(-1, 4))
    # Each row is (x0, y0, x1, y1) in editor pixels. Scale x by the width and y
    # by the height so non-square editor images map correctly, then reorder
    # each row to (y0, x0, y1, x1).
    scale = torch.tensor([model_size / W, model_size / H, model_size / W, model_size / H])
    points = (pairs * scale)[:, [1, 0, 3, 2]]

    # convert("RGB") guarantees the 3 channels that permute(2, 0, 1) and the
    # model expect, even for grayscale or RGBA inputs.
    img = transforms.Resize((model_size, model_size))(Image.fromarray(original_image).convert("RGB"))
    img = torch.from_numpy(np.array(img)).permute(2, 0, 1).float() / 255.0
    # (3, S, S) -> (1, 3, 2, S, S): a batch of one, with the image repeated along
    # axis 2 (presumably the frame axis of the '2frames' model).
    x = img[None, :, None].expand(-1, -1, 2, -1, -1)

    counterfactual = get_c(x, points).squeeze()
    return counterfactual.clamp(0, 1).permute(1, 2, 0).detach().cpu().numpy()
run_model_button.click(
    fn=run_model_on_points,
    inputs=[selected_points, input_image, original_image_high_res],
    outputs=[output_image],
)
# `queue()` returns the same Blocks object, so enable the queue, then launch.
demo.queue()
demo.launch(inbrowser=True, share=True)