# Spaces: Running on Zero
# import os | |
# os.system("pip uninstall -y gradio") | |
# os.system("pip install gradio==3.41.0") | |
import os | |
import copy | |
from PIL import Image | |
import matplotlib | |
import numpy as np | |
import gradio as gr | |
from utils import load_mask, load_mask_edit | |
from utils_mask import process_mask_to_follow_priority, mask_union, visualize_mask_list_clean | |
from pathlib import Path | |
from PIL import Image | |
from functools import partial | |
from main import run_main | |
import time | |
# Side length, in pixels, of the square area used to display and edit images.
LENGTH = 512
# Alpha value handed to `putalpha` for the segmentation overlay shown on images.
TRANSPARENCY = 150
def add_mask(mask_np_list_updated, mask_label_list):
    """Append an all-zero mask labelled ``"new"`` to the given lists, in place.

    The blank mask copies the shape and dtype of the first existing mask, so
    ``mask_np_list_updated`` must be non-empty.

    Returns:
        The same two list objects, after mutation.
    """
    template = mask_np_list_updated[0]
    mask_np_list_updated.append(np.zeros_like(template))
    mask_label_list.append("new")
    return mask_np_list_updated, mask_label_list
def create_segmentation(mask_np_list):
    """Render a list of masks as a single colour-coded RGB image.

    Mask ``i`` is painted with entry ``i`` of a viridis colormap sampled at
    ``len(mask_np_list)`` levels, and the coloured masks are summed.

    Args:
        mask_np_list: non-empty list of 2-D numpy arrays of equal shape
            (presumably 0/1 valued and non-overlapping -- TODO confirm).

    Returns:
        PIL.Image.Image: RGB image with the masks' height and width.

    Raises:
        ValueError: if ``mask_np_list`` is empty.
    """
    # `import matplotlib` alone does not load pyplot; import it explicitly
    # instead of relying on another module having done so.
    import matplotlib.pyplot as plt

    if len(mask_np_list) == 0:
        raise ValueError("create_segmentation: mask_np_list is empty")
    viridis = plt.get_cmap(name='viridis', lut=len(mask_np_list))
    segmentation = np.zeros(np.shape(mask_np_list[0]) + (3,), dtype=np.float64)
    for i, m in enumerate(mask_np_list):
        color = np.asarray(matplotlib.colors.to_rgb(viridis(i)))
        segmentation += np.asarray(m, dtype=np.float64)[:, :, np.newaxis] * color
    # Overlapping masks can sum past 1.0; clip so the uint8 cast cannot wrap.
    segmentation = np.clip(segmentation, 0.0, 1.0)
    return Image.fromarray(np.uint8(segmentation * 255))
def load_mask_ui(input_folder="example_tmp", load_edit=False):
    """Load the masks and labels stored in ``input_folder`` as numpy arrays.

    Args:
        input_folder: folder passed to the loader from ``utils``.
        load_edit: if truthy use ``load_mask_edit``, otherwise ``load_mask``.

    Returns:
        (mask_np_list, mask_label_list): each mask converted with
        ``.cpu().numpy()`` (so the loaders presumably return torch tensors),
        and the labels exactly as the loader returned them.
    """
    loader = load_mask_edit if load_edit else load_mask
    mask_list, mask_label_list = loader(input_folder)
    mask_np_list = [mask.cpu().numpy() for mask in mask_list]
    return mask_np_list, mask_label_list
def load_image_ui(load_edit, input_folder="example_tmp"):
    """Load ``img_512.png`` and its masks from ``input_folder`` for the UI.

    Args:
        load_edit: forwarded to ``load_mask_ui`` to choose the mask loader.
        input_folder: folder holding ``img_512.png`` and the mask files.

    Returns:
        Seven values, matching the ``text_button.click`` outputs:
        ``(image, segmentation, mask_np_list, mask_label_list, image,
        slider, slider)``. The slider is reconfigured to range over the mask
        indices. If anything fails, all seven values are ``None``, so the
        output count always matches.
    """
    image_path = Path(input_folder) / "img_512.png"
    try:
        # Read eagerly and close the file; convert('RGB') returns a new image.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        mask_np_list, mask_label_list = load_mask_ui(input_folder, load_edit=load_edit)
        segmentation = create_segmentation(mask_np_list)
    except Exception as err:  # UI callback boundary: report instead of crashing
        print(f"Image folder invalid: {input_folder} should contain img_512.png and its masks ({err!r})")
        return (None,) * 7
    print("!!", len(mask_np_list))
    max_val = len(mask_np_list) - 1
    sliderup = gr.Slider(value=0, minimum=0, maximum=max_val, step=1, interactive=True)
    return image, segmentation, mask_np_list, mask_label_list, image, sliderup, sliderup
def transparent_paste_with_mask(backimg, foreimg, mask_np, transparency=128):
    """Blend ``foreimg`` over ``backimg`` only where ``mask_np`` is set.

    ``foreimg`` is given a uniform alpha of ``transparency`` and pasted onto a
    copy of ``backimg``. Pixels where ``mask_np`` is 0 keep the untouched
    ``backimg`` colour.

    Args:
        backimg: background PIL image (RGB in this app).
        foreimg: overlay PIL image of the same size.
        mask_np: array of shape (height, width); presumably 0/1 valued.
        transparency: alpha passed to ``putalpha`` on the overlay copy.

    Returns:
        PIL.Image.Image built from a uint8 array.

    Raises:
        ValueError: if ``mask_np`` cannot be broadcast against the image.
    """
    backimg_solid_np = np.array(backimg)
    blended = backimg.copy()
    overlay = foreimg.copy()
    overlay.putalpha(transparency)
    blended.paste(overlay, (0, 0), overlay)
    blended_np = np.array(blended)
    mask = np.asarray(mask_np)[:, :, np.newaxis]
    # Errors propagate to the caller; the former pdb breakpoint here blocked
    # the server thread waiting on stdin.
    new_img_np = blended_np * mask + (1 - mask) * backimg_solid_np
    # Non-uint8 masks (e.g. float) promote the result to a dtype that
    # Image.fromarray rejects for RGB data; cast back to uint8.
    return Image.fromarray(new_img_np.astype(np.uint8))
def show_segmentation(image, segmentation, flag):
    """Toggle the segmentation overlay on ``image``.

    Returns:
        ``(overlayed_image, True)`` when ``flag`` is ``False``, with the
        overlay drawn over the whole image. For any other ``flag`` value,
        returns ``(image, False)`` unchanged.
    """
    if flag is not False:
        return image, False
    # PIL reports size as (width, height) but numpy arrays are
    # (height, width); the previous order only worked for square images.
    width, height = image.size
    mask_np = np.ones([height, width]).astype(np.uint8)
    image_edit = transparent_paste_with_mask(image, segmentation, mask_np, transparency=TRANSPARENCY)
    return image_edit, True
def edit_mask_add(canvas, image, idx, mask_np_list):
    """Merge the stroke drawn on ``canvas`` into mask ``idx`` and redraw.

    Args:
        canvas: sketch payload whose ``"mask"`` entry is an image array; only
            channel 0 is used.
        image: PIL image the segmentation overlay is pasted onto.
        idx: index of the mask to extend.
        mask_np_list: current list of numpy masks (not mutated).

    Returns:
        (updated_masks, overlay_image): the masks after union and priority
        processing, and ``image`` with the full segmentation pasted over it.
    """
    target = mask_np_list[idx]
    # Scale channel 0 to 0/1; the uint8 cast truncates, so only 255 maps to 1.
    stroke = np.uint8(canvas["mask"][:, :, 0] / 255.)
    merged = [mask_union(target, stroke) if pos == idx else mask
              for pos, mask in enumerate(mask_np_list)]
    # Mark the edited mask as priority 1 and all others 0; presumably this
    # lets it win where masks overlap -- see utils_mask.
    priorities = [0] * len(merged)
    priorities[idx] = 1
    merged = process_mask_to_follow_priority(merged, priorities)
    full_mask = np.ones(target.shape[:2]).astype(np.uint8)
    overlay = transparent_paste_with_mask(
        image, create_segmentation(merged), full_mask, transparency=TRANSPARENCY
    )
    return merged, overlay
def slider_release(index, image, mask_np_list_updated, mask_label_list):
    """Highlight mask ``index`` on ``image`` and return its label.

    Returns:
        ``(new_image, label)``: ``image`` with the segmentation overlay
        pasted inside the selected mask only, plus that mask's label. If no
        masks are loaded or ``index`` is not in ``0 .. len - 1``, returns the
        unchanged ``image`` and the text ``"out of range"``.
    """
    index = int(index)
    # The old `index > len` test let index == len through to an IndexError.
    if not mask_np_list_updated or not 0 <= index < len(mask_np_list_updated):
        return image, "out of range"
    mask_np = mask_np_list_updated[index]
    mask_label = mask_label_list[index]
    segmentation = create_segmentation(mask_np_list_updated)
    new_image = transparent_paste_with_mask(image, segmentation, mask_np, transparency=TRANSPARENCY)
    return new_image, mask_label
def save_as_orig_mask(mask_np_list_updated, mask_label_list, input_folder="example_tmp"):
    """Save masks as ``mask{i}_{label}.npy`` plus a ``seg_current.png`` preview.

    Before saving, checks that the masks sum to exactly 1 at every pixel. For
    0/1 masks this presumably means each pixel is covered by exactly one mask.
    If the check fails, a warning is printed and the masks are saved anyway.
    """
    print(mask_np_list_updated)
    # Explicit check rather than `assert` (stripped under -O). The old pdb
    # breakpoint here froze the server thread waiting on stdin.
    try:
        masks_valid = bool(np.all(sum(mask_np_list_updated) == 1))
    except ValueError:  # masks of mismatched shapes cannot be summed
        masks_valid = False
    if not masks_valid:
        print("please check mask")
    for midx, (mask, mask_label) in enumerate(zip(mask_np_list_updated, mask_label_list)):
        np.save(os.path.join(input_folder, "mask{}_{}.npy".format(midx, mask_label)), mask)
    savepath = os.path.join(input_folder, "seg_current.png")
    visualize_mask_list_clean(mask_np_list_updated, savepath)
def save_as_edit_mask(mask_np_list_updated, mask_label_list, input_folder="example_tmp"):
    """Save masks as ``maskEdited{i}_{label}.npy`` plus a ``seg_edited.png`` preview.

    Before saving, checks that the masks sum to exactly 1 at every pixel. For
    0/1 masks this presumably means each pixel is covered by exactly one mask.
    If the check fails, a warning is printed and the masks are saved anyway.
    """
    print(mask_np_list_updated)
    # Explicit check rather than `assert` (stripped under -O). The old pdb
    # breakpoint here froze the server thread waiting on stdin.
    try:
        masks_valid = bool(np.all(sum(mask_np_list_updated) == 1))
    except ValueError:  # masks of mismatched shapes cannot be summed
        masks_valid = False
    if not masks_valid:
        print("please check mask")
    for midx, (mask, mask_label) in enumerate(zip(mask_np_list_updated, mask_label_list)):
        np.save(os.path.join(input_folder, "maskEdited{}_{}.npy".format(midx, mask_label)), mask)
    savepath = os.path.join(input_folder, "seg_edited.png")
    visualize_mask_list_clean(mask_np_list_updated, savepath)
def image_change():
    """Empty ``./example_tmp/`` after a new upload and hide the load button.

    Files and symlinks in the folder are unlinked and sub-directories are
    removed, presumably so results for a previous image are not reused. The
    folder is deleted at startup, so it may not exist yet; that case is
    treated as already clean instead of raising FileNotFoundError.

    Returns:
        gr.Button: an invisible "1.2 Load original masks" button update.
    """
    import shutil  # local import: the module-level one appears after this def

    directory_path = "./example_tmp/"
    if os.path.isdir(directory_path):
        for filename in os.listdir(directory_path):
            file_path = os.path.join(directory_path, filename)
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
    return gr.Button("1.2 Load original masks", visible=False)
def button_clickable(is_clickable):
    """Return a Gradio button update that sets only its ``interactive`` flag."""
    update = gr.Button(interactive=is_clickable)
    return update
def load_pil_img(path="example_tmp/text/out_text_0.png"):
    """Read an image file into memory and return it as a PIL image.

    Args:
        path: file to read. The default is the file this app reads after a
            text-based editing run.

    Returns:
        PIL.Image.Image: a fully loaded copy that no longer holds the file open.
    """
    from PIL import Image

    # Image.open is lazy and keeps the file handle open; copy the pixels out
    # inside `with` so the handle is closed before returning.
    with Image.open(path) as img:
        return img.copy()
import shutil

# Begin each launch without a leftover ./example_tmp from an earlier run.
# This runs before `segment` is imported, preserving the original ordering.
if os.path.isdir("./example_tmp"):
    shutil.rmtree("./example_tmp")

from segment import run_segmentation
def run_optimization_wrapper(
    opt_flag,
    num_tokens,
    embedding_learning_rate,
    max_emb_train_steps,
    diffusion_model_learning_rate,
    max_diffusion_train_steps,
    train_batch_size,
    gradient_accumulation_steps,
):
    """Run the optimization stage of ``run_main`` with the tab-2 settings.

    The values come from Gradio components: numbers may arrive as floats and
    learning rates as text, so each is cast before being forwarded.
    ``opt_flag`` is accepted to match the event inputs but is unused.

    Returns:
        str: the status text shown in the "Response" box.
    """
    run_main(
        num_tokens=int(num_tokens),
        embedding_learning_rate=float(embedding_learning_rate),
        max_emb_train_steps=int(max_emb_train_steps),
        diffusion_model_learning_rate=float(diffusion_model_learning_rate),
        max_diffusion_train_steps=int(max_diffusion_train_steps),
        train_batch_size=int(train_batch_size),
        gradient_accumulation_steps=int(gradient_accumulation_steps),
    )
    print('finish')
    return "Optimization finished!"


def immediate_update():
    """Disable the optimization button while a run is in progress."""
    return gr.Button("Processing...", interactive=False)


def immediate_update2():
    """Re-enable the optimization button once a run has returned."""
    return gr.Button("Run Optimization (Check Log for Completion).", interactive=True)


def change_text():
    """Return a read-only textbox announcing that optimization finished."""
    return gr.Textbox("Optimization Finished!", interactive=False)


def run_edit_text_wrapper(
    num_tokens,
    guidance_scale,
    num_sampling_steps,
    strength,
    edge_thickness,
    tgt_prompt,
    tgt_index,
):
    """Run text-based editing via ``run_main`` and return the result image.

    ``tgt_index`` is the mask id from ``slider2``. ``num_tokens`` is the live
    value of the "num tokens" field; the previous code read the component's
    initial ``.value``, so a user-changed setting was silently ignored.
    """
    run_main(
        load_trained=True,
        text=True,
        num_tokens=int(num_tokens),
        guidance_scale=float(guidance_scale),
        num_sampling_steps=int(num_sampling_steps),
        strength=float(strength),
        edge_thickness=int(edge_thickness),
        num_imgs=1,
        tgt_prompt=tgt_prompt,
        tgt_index=int(tgt_index),
    )
    return load_pil_img()


with gr.Blocks() as demo:
    image = gr.State()  # store mask
    image_loaded = gr.State()
    segmentation = gr.State()
    mask_np_list = gr.State([])
    mask_label_list = gr.State([])
    mask_np_list_updated = gr.State([])
    true = gr.State(True)
    false = gr.State(False)
    block_flag = gr.State(0)
    num_tokens_global = gr.State(5)

    with gr.Row():
        gr.Markdown("""# D-Edit""")

    with gr.Tab(label="1 Edit mask"):
        with gr.Row():
            with gr.Column():
                canvas = gr.Image(
                    value="./img.png", type="numpy", label="Draw Mask", show_label=True,
                    height=LENGTH, width=LENGTH, interactive=True,
                )
                segment_button = gr.Button("1.1 Run segmentation")
                text_button = gr.Button("Waiting 1.1 to complete", visible=False)
                flag = gr.State(False)
                # Alias the same State so slider_release reads the masks that
                # text_button loads into mask_np_list.
                mask_np_list_updated = mask_np_list
            with gr.Column():
                result_info0 = gr.Text(label="Response")
                gr.Markdown("""<p style="text-align: center; font-size: 20px">Edit Mask (Optional)</p>""")
                slider = gr.Slider(0, 20, step=1, label='mask id', interactive=False)
                label = gr.Textbox()
        slider.release(
            slider_release,
            inputs=[slider, image_loaded, mask_np_list_updated, mask_label_list],
            outputs=[canvas, label],
        )
        segment_button.click(run_segmentation, [canvas], [text_button, result_info0])
        canvas.upload(image_change, inputs=[], outputs=[text_button])

    with gr.Tab(label="2 Optimization"):
        with gr.Row():
            with gr.Column():
                result_info = gr.Text(label="Response")
                opt_flag = gr.State(0)
                gr.Markdown("""<p style="text-align: center; font-size: 20px">Optimization settings (SD)</p>""")
                num_tokens = gr.Number(value=5, label="num tokens to represent each object", interactive=True)
                num_tokens_global = num_tokens
                embedding_learning_rate = gr.Textbox(value="0.0001", label="Embedding optimization: Learning rate", interactive=True)
                max_emb_train_steps = gr.Number(value=200, label="embedding optimization: Training steps", interactive=True)
                diffusion_model_learning_rate = gr.Textbox(value="0.00005", label="UNet Optimization: Learning rate", interactive=True)
                max_diffusion_train_steps = gr.Number(value=200, label="UNet Optimization: Training steps", interactive=True)
                train_batch_size = gr.Number(value=5, label="Batch size", interactive=True)
                gradient_accumulation_steps = gr.Number(value=5, label="Gradient accumulation", interactive=True)
                add_button = gr.Button("Run optimization")
        # Chain the steps so the button stays disabled until optimization
        # returns. Three independent .click listeners ran concurrently and
        # re-enabled the button immediately. .then() runs even if the
        # optimization step fails, so the button is always restored.
        add_button.click(fn=immediate_update, inputs=[], outputs=[add_button]).then(
            run_optimization_wrapper,
            inputs=[
                opt_flag,
                num_tokens,
                embedding_learning_rate,
                max_emb_train_steps,
                diffusion_model_learning_rate,
                max_diffusion_train_steps,
                train_batch_size,
                gradient_accumulation_steps,
            ],
            outputs=[result_info],
            api_name=False,
            concurrency_limit=45,
        ).then(fn=immediate_update2, inputs=[], outputs=[add_button])

    with gr.Tab(label="3 Editing"):
        with gr.Tab(label="3.1 Text-based editing"):
            with gr.Row():
                with gr.Column():
                    canvas_text_edit = gr.Image(value=None, type="pil", label="Editing results", show_label=True, visible=True)
                with gr.Column():
                    gr.Markdown("""<p style="text-align: center; font-size: 20px">Editing setting (SD)</p>""")
                    tgt_prompt = gr.Textbox(value="White bag", label="Editing: Text prompt", interactive=True)
                    slider2 = gr.Slider(0, 20, step=1, label='mask id', interactive=False)
                    guidance_scale = gr.Textbox(value="6", label="Editing: CFG guidance scale", interactive=True)
                    num_sampling_steps = gr.Number(value=50, label="Editing: Sampling steps", interactive=True)
                    edge_thickness = gr.Number(value=10, label="Editing: Edge thickness", interactive=True)
                    strength = gr.Textbox(value="0.5", label="Editing: Mask strength", interactive=True)
                    add_button = gr.Button("Run Editing (Check Log for Completion)")
            add_button.click(
                run_edit_text_wrapper,
                inputs=[
                    num_tokens_global,  # the tab-2 "num tokens" Number component
                    guidance_scale,
                    num_sampling_steps,
                    strength,
                    edge_thickness,
                    tgt_prompt,
                    slider2,
                ],
                outputs=[canvas_text_edit],
                queue=True,
            )

    # Keep the mask-id sliders of tab 1 and tab 3.1 in sync.
    slider.change(lambda x: x, inputs=[slider], outputs=[slider2])
    slider2.change(lambda x: x, inputs=[slider2], outputs=[slider])
    text_button.click(
        load_image_ui,
        [false],
        [image_loaded, segmentation, mask_np_list, mask_label_list, canvas, slider, slider2],
    )

demo.queue().launch(debug=True)