Spaces:
Sleeping
Sleeping
import sys | |
sys.path.append("../") | |
import os | |
import re | |
import time | |
import datetime | |
from copy import deepcopy | |
import numpy as np | |
import cv2 | |
import torch | |
import torch.nn.functional as F | |
import gradio as gr | |
from PIL import Image | |
from PIL.ImageOps import exif_transpose | |
from safetensors.torch import load_file | |
from utils.flow_utils import flow_to_image, resize_flow | |
from flowgen.models import UnetGenerator | |
from flowdiffusion.pipeline import FlowDiffusionPipeline | |
# Side length, in pixels, of the square canvas: input images are resized to
# LENGTH x LENGTH and the sparse drag-vector map is allocated at this size.
LENGTH = 512
# Spatial size that the image and the sparse drag map are resampled to before
# being passed to the flow generator.
FLOWGAN_RESOLUTION = [256, 256]  # HxW
# Spatial size associated with the flow diffusion stage; the generated flow is
# resized to one eighth of this before being handed to the diffusion pipeline.
FLOWDIFFUSION_RESOLUTION = [512, 512]  # HxW
def process_img(image):
    """Turn an image-editor payload into the fixed-size working image.

    Args:
        image: Mapping with a ``"composite"`` entry holding the image as an
            array, or ``None`` when nothing has been provided.
            (presumably the value emitted by a Gradio image editor -- confirm
            against the UI wiring).

    Returns:
        A 3-tuple ``(original_image, selected_points, display_image)``:

        * with a usable composite: the image resized to ``LENGTH x LENGTH``
          as a NumPy array, an empty point list, and a non-interactive
          ``gr.Image`` holding a deep copy of that array;
        * otherwise (no payload, no composite, or an all-zero composite):
          an empty non-interactive ``gr.Image``, an empty point list, and
          another empty non-interactive ``gr.Image``.
    """
    # Guard against a missing payload as well as a missing composite; the
    # unguarded subscript would raise TypeError on ``None``.
    composite = None if image is None else image["composite"]
    if composite is None or np.all(composite == 0):
        return (
            gr.Image(value=None, interactive=False),
            [],
            gr.Image(value=None, interactive=False),
        )

    resized = Image.fromarray(composite).resize((LENGTH, LENGTH), Image.BICUBIC)
    # NOTE(review): the image was just rebuilt from a raw array, so it likely
    # carries no EXIF orientation for exif_transpose to act on -- kept to
    # preserve the existing processing chain; confirm before removing.
    original_image = np.array(exif_transpose(resized))
    # The display copy is deep-copied so that drawing on it never touches the
    # pristine working image.
    return original_image, [], gr.Image(value=deepcopy(original_image), interactive=False)
def get_points(img, sel_pix, evt: gr.SelectData):
    """Record a newly clicked point and redraw all drag markers on ``img``.

    Points are interpreted pairwise in click order: even-indexed entries are
    drawn as filled circles in colour ``(255, 0, 0)``, odd-indexed entries in
    ``(0, 0, 255)``, and each completed pair is joined by a white arrow from
    the first point to the second.

    Args:
        img: Image to draw on. An ndarray is modified in place; anything else
            is converted to an ndarray first.
        sel_pix: List of previously selected points; the new point is
            appended to it in place.
        evt: Selection event whose ``index`` is the clicked coordinate.

    Returns:
        The annotated image as a NumPy array.
    """
    # The OpenCV drawing calls need an ndarray, so coerce *before* drawing.
    # (Previously the coercion happened after the drawing calls, where it
    # could no longer help a non-ndarray input.)
    img = img if isinstance(img, np.ndarray) else np.array(img)

    sel_pix.append(evt.index)
    print(sel_pix)

    pending = []  # endpoints of the pair currently being assembled
    for idx, point in enumerate(sel_pix):
        marker_color = (255, 0, 0) if idx % 2 == 0 else (0, 0, 255)
        cv2.circle(img, tuple(point), 4, marker_color, -1)
        pending.append(tuple(point))
        if len(pending) == 2:
            cv2.arrowedLine(img, pending[0], pending[1], (255, 255, 255), 2, tipLength=0.5)
            pending = []
    return img
def display_points(img, predefined_points, save_results):
    """Draw drag instructions given as text onto ``img``.

    The text is split on whitespace; every character that is not a decimal
    digit is stripped from each token and the remainder is read as an
    integer. Consecutive integers are grouped into ``[x, y]`` points, and
    consecutive points into (source, target) pairs. Sources are drawn as
    filled circles in colour ``(255, 0, 0)``, targets in ``(0, 0, 255)``, and
    each pair is joined by a white arrow.

    Args:
        img: Image to draw on. An ndarray is modified in place; anything else
            is converted to an ndarray first.
        predefined_points: Whitespace-separated coordinates, e.g.
            ``"[[10, 20], [30, 40]]"``. An empty (or ``None``) value draws
            nothing.
        save_results: When truthy, the annotated image is also written to
            ``results/drag_inst_viz/<timestamp>.png``.

    Returns:
        The annotated image as a NumPy array.
    """
    # Coerce before drawing: the OpenCV calls below require an ndarray.
    img = img if isinstance(img, np.ndarray) else np.array(img)

    # Parse the coordinates. Tokens that contain no digits are skipped rather
    # than fed to int(''), and an empty string simply yields no points
    # (previously it left ``selected_points`` unbound).
    coords = []
    for token in (predefined_points or "").split():
        digits = re.sub(r'[^0-9]', '', token)
        if digits:
            coords.append(int(digits))

    # Group into [x, y]; a trailing coordinate without a partner is dropped
    # instead of indexing past the end of the list.
    selected_points = [[coords[i], coords[i + 1]] for i in range(0, len(coords) - 1, 2)]
    print(selected_points)

    pending = []  # endpoints of the pair currently being assembled
    for idx, point in enumerate(selected_points):
        marker_color = (255, 0, 0) if idx % 2 == 0 else (0, 0, 255)
        cv2.circle(img, tuple(point), 4, marker_color, -1)
        pending.append(tuple(point))
        if len(pending) == 2:
            cv2.arrowedLine(img, pending[0], pending[1], (255, 255, 255), 2, tipLength=0.5)
            pending = []

    if save_results:
        # exist_ok avoids the check-then-create race of isdir() + makedirs().
        os.makedirs("results/drag_inst_viz", exist_ok=True)
        save_prefix = datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S")
        to_save_img = Image.fromarray(img)
        to_save_img.save(f"results/drag_inst_viz/{save_prefix}.png")
    return img
def undo_points_image(original_image):
    """Discard all selected points and restore the unannotated image.

    Args:
        original_image: The pristine working image, or ``None`` if no image
            has been loaded.

    Returns:
        ``(image, points)`` where ``points`` is always a fresh empty list and
        ``image`` is ``original_image`` itself, or an empty non-interactive
        ``gr.Image`` when ``original_image`` is ``None``.
    """
    restored = (
        gr.Image(value=None, interactive=False)
        if original_image is None
        else original_image
    )
    return restored, []
def clear_all():
    """Produce the reset values for the interface.

    Returns:
        A 5-tuple, in order: an empty interactive ``gr.Image``, two empty
        non-interactive ``gr.Image`` objects, an empty list, and ``None``.
        (Which UI components these map to is decided by the caller's output
        wiring, which is not visible here.)
    """
    editable_blank = gr.Image(value=None, interactive=True)
    first_locked_blank = gr.Image(value=None, interactive=False)
    second_locked_blank = gr.Image(value=None, interactive=False)
    return editable_blank, first_locked_blank, second_locked_blank, [], None
class InstantDragPipeline:
    """Two-stage drag editor: a flow generator followed by a flow-conditioned
    diffusion pipeline.

    ``run`` is the entry point: it stores the requested configuration, loads
    (or reuses) the two models, and performs one edit. Models are cached on
    the instance and reloaded only when the requested checkpoint name
    changes.
    """

    def __init__(self, seed=9999, device="cuda", dtype=torch.float16):
        """Store settings and create the seeded random generator.

        Args:
            seed: Seed for the ``torch.Generator`` passed to the diffusion
                pipeline.
            device: Device the models and tensors are moved to.
            dtype: Dtype used for the diffusion pipeline and its inputs.
        """
        self.seed = seed
        self.device = device
        self.dtype = dtype
        # One generator for the lifetime of the instance: it is seeded once
        # here and then advanced by every edit.
        self.generator = torch.Generator(device=device).manual_seed(seed)
        # Names of the checkpoints currently loaded (None = nothing loaded).
        self.flowgen_ckpt, self.flowdiffusion_ckpt = None, None
        self.model_config = dict()

    def build_model(self):
        """Load the models named in ``self.model_config`` if they changed.

        Each model is (re)loaded from ``checkpoints/<name>`` only when the
        requested checkpoint differs from the one currently loaded.
        """
        print("Building model...")
        if self.flowgen_ckpt != self.model_config["flowgen_ckpt"]:
            # 5 input channels: the image concatenated with the 2-channel
            # sparse drag map (see ``drag``); 2 output channels: the flow.
            self.flowgen = UnetGenerator(input_nc=5, output_nc=2)
            self.flowgen.load_state_dict(
                load_file(os.path.join("checkpoints/", self.model_config["flowgen_ckpt"]), device="cpu")
            )
            self.flowgen.to(self.device)
            self.flowgen.eval()
            self.flowgen_ckpt = self.model_config["flowgen_ckpt"]
        if self.flowdiffusion_ckpt != self.model_config["flowdiffusion_ckpt"]:
            self.flowdiffusion = FlowDiffusionPipeline.from_pretrained(
                os.path.join("checkpoints/", self.model_config["flowdiffusion_ckpt"]),
                torch_dtype=self.dtype,
                safety_checker=None
            )
            self.flowdiffusion.to(self.device)
            self.flowdiffusion_ckpt = self.model_config["flowdiffusion_ckpt"]

    def _sync_device(self):
        """Wait for pending CUDA work, but only when running on CUDA.

        An unconditional ``torch.cuda.synchronize()`` raises when the
        pipeline is configured for a non-CUDA device.
        """
        if torch.device(self.device).type == "cuda" and torch.cuda.is_available():
            torch.cuda.synchronize()

    @staticmethod
    def _build_sparse_flow(selected_points, size=None):
        """Rasterise drag pairs into a ``(1, 2, size, size)`` vector map.

        ``selected_points`` alternates source and target ``(x, y)`` points; a
        trailing source without a target is ignored. At each source pixel,
        channel 0 holds the x displacement and channel 1 the y displacement;
        every other pixel is zero. ``size`` defaults to ``LENGTH``.
        """
        size = LENGTH if size is None else size
        vector_map = torch.zeros((1, 2, size, size))
        sources = selected_points[0::2]
        targets = selected_points[1::2]
        for source, target in zip(sources, targets):
            src_x, src_y = source[0], source[1]
            vector_map[0, 0, int(src_y), int(src_x)] = target[0] - src_x
            vector_map[0, 1, int(src_y), int(src_x)] = target[1] - src_y
        return vector_map

    @staticmethod
    def _save_results(flow, edited_image, selected_points):
        """Write the flow, its visualisation, the edit and the drag points
        under ``results/``, all sharing one timestamp prefix."""
        save_prefix = datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S")
        os.makedirs("results/flows", exist_ok=True)
        np.save(f"results/flows/{save_prefix}.npy", flow[0].detach().cpu().numpy())
        os.makedirs("results/flow_visualized", exist_ok=True)
        flow_to_image(flow[0].detach()).save(f"results/flow_visualized/{save_prefix}.png")
        os.makedirs("results/edited_images", exist_ok=True)
        edited_image.save(f"results/edited_images/{save_prefix}.png")
        os.makedirs("results/drag_instructions", exist_ok=True)
        with open(f"results/drag_instructions/{save_prefix}.txt", "w") as f:
            f.write(str(selected_points))

    def drag(self, original_image, selected_points, save_results):
        """Apply the drag instructions to ``original_image``.

        Args:
            original_image: Image as an ``H x W x C`` array with values in
                ``[0, 255]`` (assumes 3 channels at ``LENGTH x LENGTH``, as
                the flow generator takes 5 input channels -- confirm against
                the caller).
            selected_points: Alternating source / target ``(x, y)`` points.
            save_results: When truthy, intermediate and final outputs are
                written under ``results/``.

        Returns:
            The edited image as a NumPy array.
        """
        scale = self.model_config["flowgen_output_scale"]
        original_image = torch.tensor(original_image).permute(2, 0, 1).unsqueeze(0).float()  # 1, C, H, W
        original_image = 2 * (original_image / 255.) - 1  # Normalize to [-1, 1]
        original_image = original_image.to(self.device)

        self._sync_device()
        start_time = time.time()

        # Generate sparse flow vectors
        point_vector_map = self._build_sparse_flow(selected_points).to(self.device)

        # Sample-wise normalize the flow vectors; the factors are kept so the
        # generated flow can be scaled back to pixel units afterwards.
        factor_x = torch.amax(torch.abs(point_vector_map[:, 0, :, :]), dim=(1, 2)).view(-1, 1, 1)
        factor_y = torch.amax(torch.abs(point_vector_map[:, 1, :, :]), dim=(1, 2)).view(-1, 1, 1)
        if factor_x >= 1e-8:  # Avoid division by zero
            point_vector_map[:, 0, :, :] /= factor_x
        if factor_y >= 1e-8:  # Avoid division by zero
            point_vector_map[:, 1, :, :] /= factor_y

        # The in-place edits to ``flow`` below must stay inside inference
        # mode, because ``flow`` is created inside it.
        with torch.inference_mode():
            gan_input_image = F.interpolate(original_image, size=FLOWGAN_RESOLUTION, mode="bicubic")
            point_vector_map = F.interpolate(point_vector_map, size=FLOWGAN_RESOLUTION, mode="bicubic")
            gan_input = torch.cat([gan_input_image, point_vector_map], dim=1)
            flow = self.flowgen(gan_input)  # -1 ~ 1

            if scale == -1.0:
                # Force each channel's range to [-1, 1]. The peak is clamped
                # so an all-zero channel stays zero instead of becoming NaN.
                for channel in (0, 1):
                    peak = torch.amax(torch.abs(flow[:, channel, :, :]), dim=(1, 2)).view(-1, 1, 1)
                    flow[:, channel, :, :] *= 1.0 / peak.clamp_min(1e-8)
            else:
                flow[:, 0, :, :] *= scale  # manually adjust the scale
                flow[:, 1, :, :] *= scale  # manually adjust the scale

            # Undo the normalization and express the flow in generator-
            # resolution pixels; a direction with no drag is zeroed out.
            if factor_x >= 1e-8:
                flow[:, 0, :, :] *= factor_x * (FLOWGAN_RESOLUTION[1] / original_image.shape[3])  # width
            else:
                flow[:, 0, :, :] *= 0
            if factor_y >= 1e-8:
                flow[:, 1, :, :] *= factor_y * (FLOWGAN_RESOLUTION[0] / original_image.shape[2])  # height
            else:
                flow[:, 1, :, :] *= 0

            # presumably 1/8 resolution matches the diffusion model's latent
            # grid -- TODO confirm against FlowDiffusionPipeline
            resized_flow = resize_flow(
                flow,
                (FLOWDIFFUSION_RESOLUTION[0] // 8, FLOWDIFFUSION_RESOLUTION[1] // 8),
                scale_type="normalize_fixed",
            )

            kwargs = {
                "image": original_image.to(self.dtype),
                "flow": resized_flow.to(self.dtype),
                "num_inference_steps": self.model_config['n_inference_step'],
                "image_guidance_scale": self.model_config['image_guidance'],
                "flow_guidance_scale": self.model_config['flow_guidance'],
                "generator": self.generator,
            }
            edited_image = self.flowdiffusion(**kwargs).images[0]

        end_time = time.time()
        inference_time = end_time - start_time
        print(f"Inference Time: {inference_time} seconds")

        if save_results:
            self._save_results(flow, edited_image, selected_points)

        edited_image = np.array(edited_image)
        return edited_image

    def run(self, original_image, selected_points,
            flowgen_ckpt, flowdiffusion_ckpt, image_guidance, flow_guidance, flowgen_output_scale,
            num_steps, save_results):
        """Configure the pipeline, ensure the models are loaded, and edit.

        Args:
            original_image: Image to edit (see ``drag``).
            selected_points: Alternating source / target ``(x, y)`` points.
            flowgen_ckpt: Flow generator checkpoint name under ``checkpoints/``.
            flowdiffusion_ckpt: Diffusion checkpoint name under ``checkpoints/``.
            image_guidance: Passed to the pipeline as ``image_guidance_scale``.
            flow_guidance: Passed to the pipeline as ``flow_guidance_scale``.
            flowgen_output_scale: Multiplier for the generated flow; ``-1.0``
                selects per-channel normalization to ``[-1, 1]`` instead.
            num_steps: Passed to the pipeline as ``num_inference_steps``.
            save_results: When truthy, outputs are written under ``results/``.

        Returns:
            The edited image as a NumPy array.
        """
        self.model_config = {
            "flowgen_ckpt": flowgen_ckpt,
            "flowdiffusion_ckpt": flowdiffusion_ckpt,
            "image_guidance": image_guidance,
            "flow_guidance": flow_guidance,
            "flowgen_output_scale": flowgen_output_scale,
            "n_inference_step": num_steps
        }
        self.build_model()
        edited_image = self.drag(original_image, selected_points, save_results)
        return edited_image