# d-edit / app.py
# Header captured from the repository web view (not part of the program):
# last commit "Update app.py" by niulx, revision c4781c0 (verified), 15.8 kB.
import os
import copy
import spaces
from PIL import Image
import matplotlib
import numpy as np
import gradio as gr
from utils import load_mask, load_mask_edit
from utils_mask import process_mask_to_follow_priority, mask_union, visualize_mask_list_clean
from pathlib import Path
from PIL import Image
from functools import partial
from main import run_main
import time
# Passed as both `height` and `width` of the mask-drawing gr.Image widget.
LENGTH=512 #length of the square area displaying/editing images
# Forwarded as `transparency` to transparent_paste_with_mask, which hands it to
# putalpha() on the segmentation overlay (presumably on a 0-255 alpha scale,
# so a larger value means a more opaque overlay -- TODO confirm).
TRANSPARENCY = 150 # transparency of the mask in display
def add_mask(mask_np_list_updated, mask_label_list):
    """Append an all-zero mask labelled ``"new"`` to the given lists.

    The new mask copies the shape and dtype of the first existing mask, so
    ``mask_np_list_updated`` must not be empty. Both lists are modified in
    place and also returned.

    Args:
        mask_np_list_updated: List of mask arrays; extended by one entry.
        mask_label_list: List of labels; extended by ``"new"``.

    Returns:
        The same two list objects, ``(mask_np_list_updated, mask_label_list)``.
    """
    template = mask_np_list_updated[0]
    blank_mask = np.zeros_like(template)
    mask_np_list_updated.append(blank_mask)
    mask_label_list.append("new")
    return mask_np_list_updated, mask_label_list
def create_segmentation(mask_np_list):
    """Render a list of masks as one colour-coded RGB image.

    Every mask gets its own colour, sampled from matplotlib's ``viridis``
    colormap resampled to ``len(mask_np_list)`` entries. The coloured masks
    are summed, scaled by 255 and cast to uint8.

    Args:
        mask_np_list: Non-empty list of 2-D arrays of identical shape.
            Presumably 0/1 valued and non-overlapping, so the summed colours
            stay within [0, 1]; overlapping masks would wrap around in the
            uint8 cast (TODO confirm against the callers).

    Returns:
        A PIL image created from the (H, W, 3) uint8 colour array.

    Raises:
        ValueError: If ``mask_np_list`` is empty.
    """
    if len(mask_np_list) == 0:
        # Previously this fell through to Image.fromarray(np.uint8(0)) and
        # failed with an unrelated-looking error.
        raise ValueError("create_segmentation needs at least one mask")

    # ``import matplotlib`` alone does not load these submodules; import them
    # explicitly instead of relying on another module having done so.
    from matplotlib import colors as mpl_colors
    from matplotlib import pyplot as plt

    viridis = plt.get_cmap(name='viridis', lut=len(mask_np_list))
    segmentation = 0
    for i, m in enumerate(mask_np_list):
        # viridis(i) is RGBA; keep RGB and broadcast it over the mask.
        color = np.asarray(mpl_colors.to_rgb(viridis(i)))
        segmentation = segmentation + m[:, :, np.newaxis] * color
    return Image.fromarray(np.uint8(segmentation * 255))
def _segmentation_failed(message):
    """Warn the user and build the 8-tuple of outputs for a failed run."""
    sliderup = gr.Slider(value=0, minimum=0, maximum=1, step=1, visible=False)
    gr.Warning(message)
    return None, None, None, None, None, sliderup, sliderup, message


def run_segmentation_wrapper(image):
    """Segment the uploaded image and prepare the outputs for the UI.

    Args:
        image: The value of the mask-drawing image widget, or ``None`` when
            nothing has been uploaded.

    Returns:
        An 8-tuple ``(image, segmentation, mask_np_list, mask_label_list,
        image, slider_update, slider_update, message)``. On failure the first
        five entries are ``None`` and the sliders are hidden.
    """
    if image is None:
        return _segmentation_failed('Please upload an image before proceeding.')

    try:
        image, mask_np_list, mask_label_list = run_segmentation(image)
        segmentation = create_segmentation(mask_np_list)
    except Exception as exc:
        # Keep the UI responsive, but report the real cause instead of
        # blaming a missing upload, and leave a traceback in the log.
        import traceback
        traceback.print_exc()
        return _segmentation_failed('Segmentation failed: {}'.format(exc))

    print("!!", len(mask_np_list))
    # One slider position per mask, indices 0 .. len - 1.
    max_val = len(mask_np_list) - 1
    sliderup = gr.Slider(value=0, minimum=0, maximum=max_val, step=1, visible=True)
    message = 'Segmentation finish. Select mask id and move to the next step.'
    gr.Info(message)
    return image, segmentation, mask_np_list, mask_label_list, image, sliderup, sliderup, message
def transparent_paste_with_mask(backimg, foreimg, mask_np, transparency=128):
    """Overlay ``foreimg`` on ``backimg`` with a fixed alpha, inside a mask.

    A copy of ``foreimg`` is given the alpha value ``transparency`` and pasted
    onto a copy of ``backimg``. The result is kept where ``mask_np`` is 1 and
    the untouched background is kept where it is 0.

    Args:
        backimg: Background image; must support ``copy()`` and ``paste()``
            (a PIL image in practice).
        foreimg: Foreground image pasted at the top-left corner.
        mask_np: 2-D array selecting where the overlay is shown. Presumably
            0/1 valued with the same height/width as the image arrays
            (TODO confirm).
        transparency: Alpha value handed to ``putalpha`` on the foreground.

    Returns:
        A new PIL image with the masked overlay applied.
    """
    untouched_pixels = np.array(backimg)

    overlay = foreimg.copy()
    overlay.putalpha(transparency)
    blended = backimg.copy()
    # The overlay acts as its own paste mask so its alpha channel is honoured.
    blended.paste(overlay, (0, 0), overlay)
    blended_pixels = np.array(blended)

    # Add a channel axis so the 2-D mask broadcasts over the colour channels.
    weight = mask_np[:, :, np.newaxis]
    composite = blended_pixels * weight + (1 - weight) * untouched_pixels
    return Image.fromarray(np.uint8(composite))
def show_segmentation(image, segmentation, flag):
    """Toggle between the plain image and the image with the overlay.

    Args:
        image: The image to display (a PIL image when the overlay is drawn).
        segmentation: Colour-coded segmentation image to overlay.
        flag: ``False`` while the overlay is hidden; any other value is
            treated as "overlay currently shown".

    Returns:
        ``(image_to_show, new_flag)``: the overlaid image and ``True`` when
        the overlay was hidden, otherwise the unchanged image and ``False``.
    """
    if flag is not False:
        return image, False

    # PIL reports size as (width, height) while the image arrays used by
    # transparent_paste_with_mask are (height, width); build the mask in array
    # order so non-square images broadcast correctly.
    width, height = image.size[0], image.size[1]
    mask_np = np.ones([height, width]).astype(np.uint8)
    image_edit = transparent_paste_with_mask(image, segmentation, mask_np, transparency=TRANSPARENCY)
    return image_edit, True
def edit_mask_add(canvas, image, idx, mask_np_list):
    """Merge the region drawn on the canvas into mask ``idx``.

    Args:
        canvas: Mapping whose ``"mask"`` entry is an (H, W, C) array; its
            first channel divided by 255 and truncated to uint8 is used as
            the drawn region.
        image: Image the updated segmentation is overlaid on.
        idx: Index of the mask receiving the drawn region.
        mask_np_list: Current list of masks; not modified.

    Returns:
        ``(updated_masks, image_edit)``: the masks returned by
        ``process_mask_to_follow_priority`` and the image with the new
        segmentation overlaid everywhere.
    """
    selected = mask_np_list[idx]
    drawn = np.uint8(canvas["mask"][:, :, 0] / 255.)

    # Replace only the selected mask with its union with the drawn region.
    merged = [
        mask_union(selected, drawn) if position == idx else existing
        for position, existing in enumerate(mask_np_list)
    ]

    # Priority 1 for the edited mask, 0 for all others (presumably the edited
    # mask wins where masks overlap -- verify against utils_mask).
    priorities = [0] * len(merged)
    priorities[idx] = 1
    merged = process_mask_to_follow_priority(merged, priorities)

    everywhere = np.ones([selected.shape[0], selected.shape[1]]).astype(np.uint8)
    segmentation = create_segmentation(merged)
    image_edit = transparent_paste_with_mask(image, segmentation, everywhere, transparency=TRANSPARENCY)
    return merged, image_edit
def slider_release(index, image, mask_np_list_updated, mask_label_list):
    """Highlight the mask chosen on the slider and return its label.

    Args:
        index: Slider position selecting a mask.
        image: Image the segmentation colours are overlaid on.
        mask_np_list_updated: Current list of masks.
        mask_label_list: Labels, parallel to the mask list.

    Returns:
        ``(image, "out of range")`` when ``index`` lies beyond the last mask;
        otherwise the image with the segmentation overlaid only inside the
        selected mask, together with that mask's label.
    """
    last_valid = len(mask_np_list_updated) - 1
    if index > last_valid:
        return image, "out of range"

    chosen_mask = mask_np_list_updated[index]
    chosen_label = mask_label_list[index]
    colour_map = create_segmentation(mask_np_list_updated)
    highlighted = transparent_paste_with_mask(image, colour_map, chosen_mask, transparency=TRANSPARENCY)
    return highlighted, chosen_label
def image_change():
    """Return the UI updates applied when a new image is uploaded.

    Returns:
        A hidden slider reset to 0 (range 0-1) and a non-interactive
        "Run Editing" button.
    """
    hidden_slider = gr.Slider(value=0, minimum=0, maximum=1, step=1, visible=False)
    disabled_button = gr.Button("Run Editing (Check log for progress.)", interactive=False)
    return hidden_slider, disabled_button
def _masks_cover_each_pixel_once(mask_list):
    """Return True when the element-wise sum of the masks is 1 everywhere."""
    try:
        return bool(np.all(sum(mask_list) == 1))
    except (TypeError, ValueError):
        # Masks that cannot even be summed (e.g. mismatched shapes) are invalid.
        return False


def _save_mask_set(mask_list, label_list, folder, file_prefix, preview_name):
    """Save every mask as ``<file_prefix><idx>_<label>.npy`` plus a preview."""
    # A plain check instead of ``assert``: asserts are stripped under -O, and
    # the old fallback dropped into pdb, which blocks a running server.
    if not _masks_cover_each_pixel_once(mask_list):
        print("please check mask")

    os.makedirs(folder, exist_ok=True)
    for midx, (mask, mask_label) in enumerate(zip(mask_list, label_list)):
        filename = "{}{}_{}.npy".format(file_prefix, midx, mask_label)
        np.save(os.path.join(folder, filename), mask)

    savepath = os.path.join(folder, preview_name)
    visualize_mask_list_clean(mask_list, savepath)


def save_as_orig_mask(mask_np_list_updated, mask_label_list, input_folder="example_tmp"):
    """Save the masks as the original mask set.

    Writes ``mask<idx>_<label>.npy`` for every mask/label pair into
    ``input_folder`` (created if missing) and passes the mask list and
    ``<input_folder>/seg_current.png`` to ``visualize_mask_list_clean``.
    A warning is printed, and saving continues, when the masks do not sum to
    1 at every pixel.

    Args:
        mask_np_list_updated: List of mask arrays to save.
        mask_label_list: Labels, parallel to the mask list.
        input_folder: Destination directory.
    """
    _save_mask_set(mask_np_list_updated, mask_label_list, input_folder, "mask", "seg_current.png")


def save_as_edit_mask(mask_np_list_updated, mask_label_list, input_folder="example_tmp"):
    """Save the masks as the edited mask set.

    Writes ``maskEdited<idx>_<label>.npy`` for every mask/label pair into
    ``input_folder`` (created if missing) and passes the mask list and
    ``<input_folder>/seg_edited.png`` to ``visualize_mask_list_clean``.
    A warning is printed, and saving continues, when the masks do not sum to
    1 at every pixel.

    Args:
        mask_np_list_updated: List of mask arrays to save.
        mask_label_list: Labels, parallel to the mask list.
        input_folder: Destination directory.
    """
    _save_mask_set(mask_np_list_updated, mask_label_list, input_folder, "maskEdited", "seg_edited.png")
def button_clickable(is_clickable):
    """Return a button update whose ``interactive`` flag is ``is_clickable``."""
    button_update = gr.Button(interactive=is_clickable)
    return button_update
def load_pil_img(path="example_tmp/text/out_text_0.png"):
    """Load an image from disk as a PIL image.

    Args:
        path: File to read. Defaults to the location the original code read
            after a text-based edit, ``example_tmp/text/out_text_0.png``.

    Returns:
        An in-memory PIL image that no longer depends on the file.
    """
    from PIL import Image

    # Image.open is lazy and keeps the file handle open; copy the pixels out
    # and close the file so a later run can overwrite it safely.
    with Image.open(path) as opened:
        return opened.copy()
import shutil
# Remove the working directory left over from a previous launch; the mask
# savers and load_pil_img read and write files under "example_tmp".
_stale_tmp_dir = "./example_tmp"
if os.path.isdir(_stale_tmp_dir):
    shutil.rmtree(_stale_tmp_dir)
from segment import run_segmentation
# Gradio UI: tab 1 segments the image, tab 2 runs the optimization, tab 3 runs
# text-based editing.
# NOTE(review): the nesting of the Row/Column/Tab contexts below was
# reconstructed from the statement order -- confirm it matches the intended
# layout.
with gr.Blocks() as demo:
    # Per-session state holders.
    image = gr.State()
    image_loaded = gr.State()
    segmentation = gr.State()
    mask_np_list = gr.State([])
    mask_label_list = gr.State([])
    mask_np_list_updated = gr.State([])
    true = gr.State(True)
    false = gr.State(False)
    block_flag = gr.State(0)
    num_tokens_global = gr.State(5)

    with gr.Row():
        gr.Markdown("""# D-Edit""")

    with gr.Tab(label="1 Edit mask"):
        with gr.Row():
            with gr.Column():
                canvas = gr.Image(value = "./img.png", type="numpy", label="Draw Mask", show_label=True, height=LENGTH, width=LENGTH, interactive=True)
                result_info0 = gr.Text(label="Response")
                segment_button = gr.Button("Run segmentation")
                flag = gr.State(False)
                # From here on the "updated" mask list is the same State
                # object as the original one (no separate deep copy is kept).
                mask_np_list_updated = mask_np_list
            with gr.Column():
                gr.Markdown("""<p style="text-align: center; font-size: 20px">Edit Mask (Optional)</p>""")
                slider = gr.Slider(0, 20, step=1, label = 'mask id', visible=False)
                label = gr.Text(label='label')
                slider.release(slider_release,
                    inputs = [slider, image_loaded, mask_np_list_updated, mask_label_list],
                    outputs= [canvas, label]
                )

    with gr.Tab(label="2 Optimization"):
        with gr.Row():
            with gr.Column():
                result_info = gr.Text(label="Response")
                opt_flag = gr.State(0)
                gr.Markdown("""<p style="text-align: center; font-size: 20px">Optimization settings (SD)</p>""")
                num_tokens = gr.Number(value="5", label="num tokens to represent each object", interactive= True)
                # Rebinds the name to the Number component so the editing
                # handler receives the value the user entered here.
                num_tokens_global = num_tokens
                embedding_learning_rate = gr.Textbox(value="0.00005", label="Embedding optimization: Learning rate", interactive= True )
                max_emb_train_steps = gr.Number(value="30", label="embedding optimization: Training steps", interactive= True )
                diffusion_model_learning_rate = gr.Textbox(value="0.00002", label="UNet Optimization: Learning rate", interactive= True )
                max_diffusion_train_steps = gr.Number(value="30", label="UNet Optimization: Learning rate: Training steps", interactive= True )
                train_batch_size = gr.Number(value="16", label="Batch size", interactive= True )
                gradient_accumulation_steps=gr.Number(value="5", label="Gradient accumulation", interactive= True )
                add_button = gr.Button("Run optimization")

                def run_optimization_wrapper(
                    mask_np_list,
                    mask_label_list,
                    image,
                    opt_flag,
                    num_tokens,
                    embedding_learning_rate,
                    max_emb_train_steps,
                    diffusion_model_learning_rate,
                    max_diffusion_train_steps,
                    train_batch_size,
                    gradient_accumulation_steps,
                ):
                    """Run the optimization stage with the settings from the UI.

                    The numeric settings arrive as widget values and are
                    converted with ``int``/``float`` before being passed to
                    ``run_main``. ``opt_flag`` is accepted but not used.

                    Returns:
                        ``(status_message, button_update)``; the "Run Editing"
                        button is enabled only after a successful run.
                    """
                    try:
                        run_main(
                            mask_np_list=mask_np_list,
                            mask_label_list=mask_label_list,
                            image_gt=np.array(image),
                            num_tokens=int(num_tokens),
                            embedding_learning_rate=float(embedding_learning_rate),
                            max_emb_train_steps=int(max_emb_train_steps),
                            diffusion_model_learning_rate=float(diffusion_model_learning_rate),
                            max_diffusion_train_steps=int(max_diffusion_train_steps),
                            train_batch_size=int(train_batch_size),
                            gradient_accumulation_steps=int(gradient_accumulation_steps),
                        )
                    except Exception as e:
                        # gr.Error only shows up when raised, and raising would
                        # skip the outputs below; log the traceback and show
                        # the real message as a warning instead.
                        import traceback
                        traceback.print_exc()
                        gr.Warning("Optimization failed: {}".format(e))
                        return "Error: use a smaller batch size or try latter.",gr.Button("Run Editing (Check log for progress.)",interactive = False)
                    gr.Info("Optimization Finished! Move to the next step.")
                    return "Optimization finished! Move to the next step.",gr.Button("Run Editing (Check log for progress.)",interactive = True)

    with gr.Tab(label="3 Editing"):
        with gr.Tab(label="3.1 Text-based editing"):
            with gr.Row():
                with gr.Column():
                    canvas_text_edit = gr.Image(value = None, type = "pil", label="Editing results", show_label=True,visible = True)
                with gr.Column():
                    gr.Markdown("""<p style="text-align: center; font-size: 20px">Editing setting (SD)</p>""")
                    tgt_prompt = gr.Textbox(value="White bag", label="Editing: Text prompt", interactive= True )
                    # Hidden mirror of the tab-1 mask slider; supplies the
                    # target object index to the editing handler.
                    slider2 = gr.Slider(0, 20, step=1, label = 'mask id', visible=False)
                    guidance_scale = gr.Textbox(value="6", label="Editing: CFG guidance scale", interactive= True )
                    num_sampling_steps = gr.Number(value="50", label="Editing: Sampling steps", interactive= True )
                    edge_thickness = gr.Number(value="10", label="Editing: Edge thickness", interactive= True )
                    strength = gr.Textbox(value="0.5", label="Editing: Mask strength", interactive= True )
                    add_button2 = gr.Button("Run Editing (Check log for progress.)",interactive = False)

                    def run_edit_text_wrapper(
                        mask_np_list,
                        mask_label_list,
                        image,
                        num_tokens,
                        guidance_scale,
                        num_sampling_steps,
                        strength,
                        edge_thickness,
                        tgt_prompt,
                        tgt_index,
                    ):
                        """Run text-guided editing and return the edited image.

                        Calls ``run_main`` with ``load_trained=True`` and
                        ``text=True`` for a single output image, then loads
                        the result via ``load_pil_img``.
                        """
                        run_main(
                            mask_np_list=mask_np_list,
                            mask_label_list=mask_label_list,
                            image_gt=np.array(image),
                            load_trained=True,
                            text=True,
                            # Use the value passed in by the event; reading
                            # ``.value`` off the component only ever yields
                            # its initial value.
                            num_tokens=int(num_tokens),
                            guidance_scale=float(guidance_scale),
                            num_sampling_steps=int(num_sampling_steps),
                            strength=float(strength),
                            edge_thickness=int(edge_thickness),
                            num_imgs=1,
                            tgt_prompt=tgt_prompt,
                            tgt_index=int(tgt_index),
                        )
                        gr.Info('Image editing completed.')
                        return load_pil_img()

    # Event wiring.
    # A new upload hides the mask slider and disables editing again.
    canvas.upload(image_change, inputs=[], outputs=[slider,add_button2])

    add_button.click(run_optimization_wrapper,
        inputs = [
            mask_np_list,
            mask_label_list,
            image_loaded,
            opt_flag,
            num_tokens,
            embedding_learning_rate ,
            max_emb_train_steps ,
            diffusion_model_learning_rate ,
            max_diffusion_train_steps,
            train_batch_size,
            gradient_accumulation_steps
        ],
        outputs = [result_info,add_button2], api_name=False, concurrency_limit=45)

    add_button2.click(run_edit_text_wrapper,
        inputs = [ mask_np_list,
            mask_label_list,
            image_loaded,num_tokens_global,
            guidance_scale,
            num_sampling_steps,
            strength ,
            edge_thickness,
            tgt_prompt ,
            slider2
        ],
        outputs = [canvas_text_edit],queue=True)

    # Keep the hidden editing slider in sync with the visible mask slider.
    slider.change(
        lambda x: x,
        inputs=[slider],
        outputs=[slider2]
    )

    segment_button.click(run_segmentation_wrapper,
        [canvas] ,
        [image_loaded, segmentation, mask_np_list, mask_label_list, canvas, slider, slider2, result_info0] )

demo.queue().launch(debug=True)