# InstantDrag/demo/demo_utils.py
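"""Utilities and the InstantDragPipeline wrapper for the InstantDrag Gradio demo."""
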
import sys
sys.path.append("../")  # make repo-root packages (utils, flowgen, flowdiffusion) importable
import os
import re
import time
import datetime
from copy import deepcopy

import numpy as np
import cv2
import torch
import torch.nn.functional as F
import gradio as gr
from PIL import Image
from PIL.ImageOps import exif_transpose
from safetensors.torch import load_file

from utils.flow_utils import flow_to_image, resize_flow
from flowgen.models import UnetGenerator
from flowdiffusion.pipeline import FlowDiffusionPipeline

LENGTH = 512  # working resolution of the demo canvas (pixels)
FLOWGAN_RESOLUTION = [256, 256]  # H x W
FLOWDIFFUSION_RESOLUTION = [512, 512]  # H x W


def process_img(image):
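    """Load the sketchpad composite, resize it to LENGTH x LENGTH, and reset the point state."""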
if image["composite"] is not None and not np.all(image["composite"] == 0):
original_image = Image.fromarray(image["composite"]).resize((LENGTH, LENGTH), Image.BICUBIC)
original_image = np.array(exif_transpose(original_image))
return original_image, [], gr.Image(value=deepcopy(original_image), interactive=False)
else:
return (
gr.Image(value=None, interactive=False),
[],
gr.Image(value=None, interactive=False),
)


def get_points(img, sel_pix, evt: gr.SelectData):
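    """Record a clicked point and redraw markers (red = source, blue = target, arrow per pair)."""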
    img = img if isinstance(img, np.ndarray) else np.array(img)  # cv2 drawing needs an ndarray
    sel_pix.append(evt.index)
    print(sel_pix)
    points = []
    for idx, point in enumerate(sel_pix):
        if idx % 2 == 0:
            cv2.circle(img, tuple(point), 4, (255, 0, 0), -1)  # source point in red
        else:
            cv2.circle(img, tuple(point), 4, (0, 0, 255), -1)  # target point in blue
        points.append(tuple(point))
        if len(points) == 2:
            cv2.arrowedLine(img, points[0], points[1], (255, 255, 255), 2, tipLength=0.5)
            points = []
    return img


def display_points(img, predefined_points, save_results):
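    """Draw a whitespace-separated "x1 y1 x2 y2 ..." point string onto img; optionally save the visualization."""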
    selected_points = []  # avoid a NameError when no predefined points are given
    if predefined_points != "":
predefined_points = predefined_points.split()
predefined_points = [int(re.sub(r'[^0-9]', '', point)) for point in predefined_points]
processed_points = []
for i, point in enumerate(predefined_points):
if i % 2 == 0:
processed_points.append([point, predefined_points[i+1]])
selected_points = processed_points
print(selected_points)
    img = img if isinstance(img, np.ndarray) else np.array(img)  # cv2 drawing needs an ndarray
    points = []
    for idx, point in enumerate(selected_points):
        if idx % 2 == 0:
            cv2.circle(img, tuple(point), 4, (255, 0, 0), -1)  # source point in red
        else:
            cv2.circle(img, tuple(point), 4, (0, 0, 255), -1)  # target point in blue
        points.append(tuple(point))
        if len(points) == 2:
            cv2.arrowedLine(img, points[0], points[1], (255, 255, 255), 2, tipLength=0.5)
            points = []
if save_results:
if not os.path.isdir("results/drag_inst_viz"):
os.makedirs("results/drag_inst_viz")
save_prefix = datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S")
to_save_img = Image.fromarray(img)
to_save_img.save(f"results/drag_inst_viz/{save_prefix}.png")
return img


def undo_points_image(original_image):
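    """Revert to the unannotated image and clear the selected points."""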
if original_image is not None:
return original_image, []
else:
return gr.Image(value=None, interactive=False), []


def clear_all():
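    """Reset all demo widgets: clear the images and the selected-point list."""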
return (
gr.Image(value=None, interactive=True),
gr.Image(value=None, interactive=False),
gr.Image(value=None, interactive=False),
[],
None
)


class InstantDragPipeline:
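    """Two-stage drag-editing pipeline: FlowGen predicts a dense flow field from
    sparse drag vectors, and FlowDiffusion renders the edited image from it."""
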
def __init__(self, seed=9999, device="cuda", dtype=torch.float16):
self.seed = seed
self.device = device
self.dtype = dtype
self.generator = torch.Generator(device=device).manual_seed(seed)
self.flowgen_ckpt, self.flowdiffusion_ckpt = None, None
self.model_config = dict()

    def build_model(self):
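        """Load FlowGen and FlowDiffusion weights, reloading only when the configured checkpoint changes."""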
print("Building model...")
if self.flowgen_ckpt != self.model_config["flowgen_ckpt"]:
self.flowgen = UnetGenerator(input_nc=5, output_nc=2)
self.flowgen.load_state_dict(
load_file(os.path.join("checkpoints/", self.model_config["flowgen_ckpt"]), device="cpu")
)
self.flowgen.to(self.device)
self.flowgen.eval()
self.flowgen_ckpt = self.model_config["flowgen_ckpt"]
if self.flowdiffusion_ckpt != self.model_config["flowdiffusion_ckpt"]:
self.flowdiffusion = FlowDiffusionPipeline.from_pretrained(
os.path.join("checkpoints/", self.model_config["flowdiffusion_ckpt"]),
torch_dtype=self.dtype,
safety_checker=None
)
self.flowdiffusion.to(self.device)
self.flowdiffusion_ckpt = self.model_config["flowdiffusion_ckpt"]

    def drag(self, original_image, selected_points, save_results):
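        """Rasterize the drag vectors into a sparse flow map, predict dense flow
        with FlowGen, then generate the edited image with FlowDiffusion."""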
scale = self.model_config["flowgen_output_scale"]
original_image = torch.tensor(original_image).permute(2, 0, 1).unsqueeze(0).float() # 1, 3, 512, 512
original_image = 2 * (original_image / 255.) - 1 # Normalize to [-1, 1]
original_image = original_image.to(self.device)
source_points = []
target_points = []
for idx, point in enumerate(selected_points):
cur_point = torch.tensor([point[0], point[1]]) # x, y
if idx % 2 == 0:
source_points.append(cur_point)
else:
target_points.append(cur_point)
        torch.cuda.synchronize()  # flush pending GPU work so the timing below is accurate
        start_time = time.time()
# Generate sparse flow vectors
point_vector_map = torch.zeros((1, 2, LENGTH, LENGTH))
for source_point, target_point in zip(source_points, target_points):
cur_x, cur_y = source_point[0], source_point[1]
target_x, target_y = target_point[0], target_point[1]
vec_x = target_x - cur_x
vec_y = target_y - cur_y
point_vector_map[0, 0, int(cur_y), int(cur_x)] = vec_x
point_vector_map[0, 1, int(cur_y), int(cur_x)] = vec_y
point_vector_map = point_vector_map.to(self.device)
# Sample-wise normalize the flow vectors
factor_x = torch.amax(torch.abs(point_vector_map[:, 0, :, :]), dim=(1, 2)).view(-1, 1, 1).to(self.device)
factor_y = torch.amax(torch.abs(point_vector_map[:, 1, :, :]), dim=(1, 2)).view(-1, 1, 1).to(self.device)
if factor_x >= 1e-8: # Avoid division by zero
point_vector_map[:, 0, :, :] /= factor_x
if factor_y >= 1e-8: # Avoid division by zero
point_vector_map[:, 1, :, :] /= factor_y
with torch.inference_mode():
gan_input_image = F.interpolate(original_image, size=FLOWGAN_RESOLUTION, mode="bicubic") # 256 x 256
point_vector_map = F.interpolate(point_vector_map, size=FLOWGAN_RESOLUTION, mode="bicubic") # 256 x 256
gan_input = torch.cat([gan_input_image, point_vector_map], dim=1)
flow = self.flowgen(gan_input) # -1 ~ 1
if scale == -1.0:
flow[:, 0, :, :] *= 1.0 / torch.amax(torch.abs(flow[:, 0, :, :]), dim=(1, 2)).view(-1, 1, 1) # force the range to be [-1 ~ 1]
flow[:, 1, :, :] *= 1.0 / torch.amax(torch.abs(flow[:, 1, :, :]), dim=(1, 2)).view(-1, 1, 1) # force the range to be [-1 ~ 1]
else:
flow[:, 0, :, :] *= scale # manually adjust the scale
flow[:, 1, :, :] *= scale # manually adjust the scale
if factor_x >= 1e-8:
flow[:, 0, :, :] *= factor_x * (FLOWGAN_RESOLUTION[1] / original_image.shape[3]) # width
else:
flow[:, 0, :, :] *= 0
if factor_y >= 1e-8:
flow[:, 1, :, :] *= factor_y * (FLOWGAN_RESOLUTION[0] / original_image.shape[2]) # height
else:
flow[:, 1, :, :] *= 0
resized_flow = resize_flow(flow, (FLOWDIFFUSION_RESOLUTION[0]//8, FLOWDIFFUSION_RESOLUTION[1]//8), scale_type="normalize_fixed")
kwargs = {
"image": original_image.to(self.dtype),
"flow": resized_flow.to(self.dtype),
"num_inference_steps": self.model_config['n_inference_step'],
"image_guidance_scale": self.model_config['image_guidance'],
"flow_guidance_scale": self.model_config['flow_guidance'],
"generator": self.generator,
}
edited_image = self.flowdiffusion(**kwargs).images[0]
end_time = time.time()
inference_time = end_time - start_time
print(f"Inference Time: {inference_time} seconds")
if save_results:
save_prefix = datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S")
if not os.path.isdir("results/flows"):
os.makedirs("results/flows")
np.save(f"results/flows/{save_prefix}.npy", flow[0].detach().cpu().numpy())
if not os.path.isdir("results/flow_visualized"):
os.makedirs("results/flow_visualized")
flow_to_image(flow[0].detach()).save(f"results/flow_visualized/{save_prefix}.png")
if not os.path.isdir("results/edited_images"):
os.makedirs("results/edited_images")
edited_image.save(f"results/edited_images/{save_prefix}.png")
if not os.path.isdir("results/drag_instructions"):
os.makedirs("results/drag_instructions")
with open(f"results/drag_instructions/{save_prefix}.txt", "w") as f:
f.write(str(selected_points))
edited_image = np.array(edited_image)
return edited_image

    def run(self, original_image, selected_points,
flowgen_ckpt, flowdiffusion_ckpt, image_guidance, flow_guidance, flowgen_output_scale,
num_steps, save_results):
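        """Store the UI settings, (re)build the models if needed, and run the drag edit."""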
self.model_config = {
"flowgen_ckpt": flowgen_ckpt,
"flowdiffusion_ckpt": flowdiffusion_ckpt,
"image_guidance": image_guidance,
"flow_guidance": flow_guidance,
"flowgen_output_scale": flowgen_output_scale,
"n_inference_step": num_steps
}
self.build_model()
edited_image = self.drag(original_image, selected_points, save_results)
return edited_image
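

# Minimal usage sketch (not part of the original demo; the checkpoint names,
# image path, and point coordinates below are placeholders -- adjust to your setup):
#
#   pipeline = InstantDragPipeline(seed=9999, device="cuda", dtype=torch.float16)
#   image = np.array(Image.open("example.png").resize((LENGTH, LENGTH)))  # H x W x 3, uint8
#   points = [[200, 250], [260, 250]]  # alternating source/target (x, y) pairs
#   edited = pipeline.run(image, points,
#                         flowgen_ckpt="flowgen.safetensors",
#                         flowdiffusion_ckpt="flowdiffusion",
#                         image_guidance=1.5, flow_guidance=3.0,
#                         flowgen_output_scale=-1.0, num_steps=20,
#                         save_results=False)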