|
|
|
from torchvision.utils import save_image |
|
from PIL import Image |
|
from pytorch_lightning import seed_everything |
|
import subprocess |
|
from collections import OrderedDict |
|
import re |
|
import cv2 |
|
import einops |
|
import gradio as gr |
|
import numpy as np |
|
import torch |
|
import random |
|
import os |
|
import requests |
|
from io import BytesIO |
|
from annotator.util import resize_image, HWC3, resize_points, get_bounding_box |
|
|
|
from safetensors.torch import load_file |
|
from collections import defaultdict |
|
from diffusers import StableDiffusionControlNetPipeline |
|
from diffusers import ControlNetModel, UniPCMultistepScheduler |
|
|
|
from utils.stable_diffusion_controlnet import ControlNetModel2 |
|
from utils.stable_diffusion_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline, \ |
|
StableDiffusionControlNetInpaintMixingPipeline, prepare_mask_image |
|
|
|
|
|
from transformers import AutoProcessor, Blip2ForConditionalGeneration |
|
from diffusers import DiffusionPipeline
|
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput |
|
import PIL.Image |
|
|
|
|
|
|
|
try: |
|
from segment_anything import ( |
|
sam_model_registry, |
|
SamAutomaticMaskGenerator, |
|
SamPredictor, |
|
) |
|
except ImportError: |
|
print("segment_anything not installed") |
|
result = subprocess.run( |
|
[ |
|
"pip", |
|
"install", |
|
"git+https://github.com/facebookresearch/segment-anything.git", |
|
], |
|
check=True, |
|
) |
|
print(f"Install segment_anything {result}") |
|
from segment_anything import ( |
|
sam_model_registry, |
|
SamAutomaticMaskGenerator, |
|
SamPredictor, |
|
) |
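# Download the SAM ViT-H checkpoint into ./models if it is not already present.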
|
if not os.path.exists("./models/sam_vit_h_4b8939.pth"): |
|
result = subprocess.run( |
|
[ |
|
"wget", |
|
"https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", |
|
"-P", |
|
"models", |
|
], |
|
check=True, |
|
) |
|
print(f"Download sam_vit_h_4b8939.pth {result}") |
|
|
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
config_dict = OrderedDict( |
|
[ |
|
("LAION Pretrained(v0-4)-SD15", "shgao/edit-anything-v0-4-sd15"), |
|
("LAION Pretrained(v0-4)-SD21", "shgao/edit-anything-v0-4-sd21"), |
|
("LAION Pretrained(v0-3)-SD21", "shgao/edit-anything-v0-3"), |
|
("SAM Pretrained(v0-1)-SD21", "shgao/edit-anything-v0-1-1"), |
|
] |
|
) |
|
|
|
|
|
def init_sam_model(sam_generator=None, mask_predictor=None): |
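    """Create (or reuse) the SAM automatic mask generator and point-prompt predictor.

    Loads the ViT-H checkpoint from models/sam_vit_h_4b8939.pth onto the global
    device; generator/predictor instances passed in are returned unchanged.
    """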
|
if sam_generator is not None and mask_predictor is not None: |
|
return sam_generator, mask_predictor |
|
sam_checkpoint = "models/sam_vit_h_4b8939.pth" |
|
model_type = "default" |
|
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) |
|
sam.to(device=device) |
|
sam_generator = ( |
|
SamAutomaticMaskGenerator( |
|
sam) if sam_generator is None else sam_generator |
|
) |
|
mask_predictor = SamPredictor( |
|
sam) if mask_predictor is None else mask_predictor |
|
return sam_generator, mask_predictor |
|
|
|
|
|
def init_blip_processor(): |
|
blip_processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b") |
|
return blip_processor |
|
|
|
|
|
def init_blip_model(): |
|
blip_model = Blip2ForConditionalGeneration.from_pretrained( |
|
"Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16, device_map="auto" |
|
) |
|
return blip_model |
|
|
|
|
|
def get_pipeline_embeds(pipeline, prompt, negative_prompt, device): |
|
|
|
"""Get pipeline embeds for prompts bigger than the maxlength of the pipe |
|
:param pipeline: |
|
:param prompt: |
|
:param negative_prompt: |
|
:param device: |
|
:return: |
|
""" |
|
max_length = pipeline.tokenizer.model_max_length |
|
|
|
|
|
count_prompt = len(re.split(r", ", prompt)) |
|
count_negative_prompt = len(re.split(r", ", negative_prompt)) |
|
|
|
|
|
if count_prompt >= count_negative_prompt: |
|
input_ids = pipeline.tokenizer( |
|
prompt, return_tensors="pt", truncation=False |
|
).input_ids.to(device) |
|
shape_max_length = input_ids.shape[-1] |
|
negative_ids = pipeline.tokenizer( |
|
negative_prompt, |
|
truncation=False, |
|
padding="max_length", |
|
max_length=shape_max_length, |
|
return_tensors="pt", |
|
).input_ids.to(device) |
|
else: |
|
negative_ids = pipeline.tokenizer( |
|
negative_prompt, return_tensors="pt", truncation=False |
|
).input_ids.to(device) |
|
shape_max_length = negative_ids.shape[-1] |
|
input_ids = pipeline.tokenizer( |
|
prompt, |
|
return_tensors="pt", |
|
truncation=False, |
|
padding="max_length", |
|
max_length=shape_max_length, |
|
).input_ids.to(device) |
|
|
|
concat_embeds = [] |
|
neg_embeds = [] |
|
for i in range(0, shape_max_length, max_length): |
|
concat_embeds.append(pipeline.text_encoder(input_ids[:, i : i + max_length])[0]) |
|
neg_embeds.append(pipeline.text_encoder(negative_ids[:, i : i + max_length])[0]) |
|
|
|
return torch.cat(concat_embeds, dim=1), torch.cat(neg_embeds, dim=1) |
|
|
|
|
|
def load_lora_weights(pipeline, checkpoint_path, multiplier, device, dtype): |
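    """Merge LoRA weights stored in safetensors (kohya ``lora_unet``/``lora_te`` naming)
    into the pipeline's UNet and text encoder in place.

    ``checkpoint_path`` may be a single path or a list of paths; each update adds
    ``multiplier * alpha * (up @ down)`` to the matching layer. Returns the pipeline.
    """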
|
LORA_PREFIX_UNET = "lora_unet" |
|
LORA_PREFIX_TEXT_ENCODER = "lora_te" |
|
|
|
print('device: {}'.format(device)) |
|
if isinstance(checkpoint_path, str): |
|
state_dict = load_file(checkpoint_path, device=device) |
|
|
|
updates = defaultdict(dict) |
|
for key, value in state_dict.items(): |
|
|
|
|
|
|
|
layer, elem = key.split(".", 1) |
|
updates[layer][elem] = value |
|
|
|
|
|
for layer, elems in updates.items(): |
|
|
|
if "text" in layer: |
|
layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") |
|
curr_layer = pipeline.text_encoder |
|
else: |
|
layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_") |
|
curr_layer = pipeline.unet |
|
|
|
|
|
temp_name = layer_infos.pop(0) |
|
while len(layer_infos) > -1: |
|
try: |
|
curr_layer = curr_layer.__getattr__(temp_name) |
|
if len(layer_infos) > 0: |
|
temp_name = layer_infos.pop(0) |
|
elif len(layer_infos) == 0: |
|
break |
|
except Exception: |
|
if len(temp_name) > 0: |
|
temp_name += "_" + layer_infos.pop(0) |
|
else: |
|
temp_name = layer_infos.pop(0) |
|
|
|
|
|
weight_up = elems["lora_up.weight"].to(dtype) |
|
weight_down = elems["lora_down.weight"].to(dtype) |
|
alpha = elems["alpha"] |
|
if alpha: |
|
alpha = alpha.item() / weight_up.shape[1] |
|
else: |
|
alpha = 1.0 |
|
|
|
|
|
if len(weight_up.shape) == 4: |
|
curr_layer.weight.data += ( |
|
multiplier |
|
* alpha |
|
* torch.mm( |
|
weight_up.squeeze(3).squeeze(2), |
|
weight_down.squeeze(3).squeeze(2), |
|
) |
|
.unsqueeze(2) |
|
.unsqueeze(3) |
|
) |
|
else: |
|
curr_layer.weight.data += ( |
|
multiplier * alpha * torch.mm(weight_up, weight_down) |
|
) |
|
else: |
|
for ckptpath in checkpoint_path: |
|
state_dict = load_file(ckptpath, device=device) |
|
|
|
updates = defaultdict(dict) |
|
for key, value in state_dict.items(): |
|
|
|
|
|
|
|
layer, elem = key.split(".", 1) |
|
updates[layer][elem] = value |
|
|
|
|
|
for layer, elems in updates.items(): |
|
if "text" in layer: |
|
layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split( |
|
"_" |
|
) |
|
curr_layer = pipeline.text_encoder |
|
else: |
|
layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_") |
|
curr_layer = pipeline.unet |
|
|
|
|
|
temp_name = layer_infos.pop(0) |
|
while len(layer_infos) > -1: |
|
try: |
|
curr_layer = curr_layer.__getattr__(temp_name) |
|
if len(layer_infos) > 0: |
|
temp_name = layer_infos.pop(0) |
|
elif len(layer_infos) == 0: |
|
break |
|
except Exception: |
|
if len(temp_name) > 0: |
|
temp_name += "_" + layer_infos.pop(0) |
|
else: |
|
temp_name = layer_infos.pop(0) |
|
|
|
|
|
weight_up = elems["lora_up.weight"].to(dtype) |
|
weight_down = elems["lora_down.weight"].to(dtype) |
|
alpha = elems["alpha"] |
|
if alpha: |
|
alpha = alpha.item() / weight_up.shape[1] |
|
else: |
|
alpha = 1.0 |
|
|
|
|
|
if len(weight_up.shape) == 4: |
|
curr_layer.weight.data += ( |
|
multiplier |
|
* alpha |
|
* torch.mm( |
|
weight_up.squeeze(3).squeeze(2), |
|
weight_down.squeeze(3).squeeze(2), |
|
) |
|
.unsqueeze(2) |
|
.unsqueeze(3) |
|
) |
|
else: |
|
curr_layer.weight.data += ( |
|
multiplier * alpha * torch.mm(weight_up, weight_down) |
|
) |
|
return pipeline |
|
|
|
|
|
def make_inpaint_condition(image, image_mask): |
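    """Build the ControlNet inpainting condition: normalize the image to [0, 1],
    set masked pixels (mask value > 128) to -1.0, and return a (1, C, H, W) tensor."""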
|
    image = image.astype(np.float32) / 255.0
    assert (
        image.shape[0:2] == image_mask.shape[0:2]
    ), "image and image_mask must have the same image size"
|
image[image_mask > 128] = -1.0 |
|
image = np.expand_dims(image, 0).transpose(0, 3, 1, 2) |
|
image = torch.from_numpy(image) |
|
return image |
|
|
|
|
|
def obtain_generation_model( |
|
base_model_path, |
|
lora_model_path, |
|
controlnet_path, |
|
generation_only=False, |
|
extra_inpaint=True, |
|
lora_weight=1.0, |
|
): |
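    """Assemble the main ControlNet generation/inpainting pipeline.

    A plain StableDiffusionControlNetPipeline is used when both ``generation_only``
    and ``extra_inpaint`` are set; otherwise the ControlNet inpainting pipeline is
    used, optionally with an extra SD1.5 inpaint ControlNet. LoRA weights are merged
    when ``lora_model_path`` is given; the UniPC scheduler, xformers attention, and
    model CPU offload are enabled.
    """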
|
controlnet = [] |
|
controlnet.append( |
|
ControlNetModel2.from_pretrained( |
|
controlnet_path, torch_dtype=torch.float16) |
|
) |
|
if (not generation_only) and extra_inpaint: |
|
print("Warning: ControlNet based inpainting model only support SD1.5 for now.") |
|
controlnet.append( |
|
ControlNetModel.from_pretrained( |
|
"lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16 |
|
) |
|
) |
|
|
|
if generation_only and extra_inpaint: |
|
pipe = StableDiffusionControlNetPipeline.from_pretrained( |
|
base_model_path, |
|
controlnet=controlnet, |
|
torch_dtype=torch.float16, |
|
safety_checker=None, |
|
) |
|
else: |
|
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained( |
|
base_model_path, |
|
controlnet=controlnet, |
|
torch_dtype=torch.float16, |
|
safety_checker=None, |
|
) |
|
if lora_model_path is not None: |
|
pipe = load_lora_weights( |
|
pipe, [lora_model_path], lora_weight, "cpu", torch.float32 |
|
) |
|
|
|
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) |
|
|
|
pipe.enable_xformers_memory_efficient_attention() |
|
|
|
pipe.enable_model_cpu_offload() |
|
return pipe |
|
|
|
|
|
def obtain_tile_model(base_model_path, lora_model_path, lora_weight=1.0): |
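    """Build the ControlNet tile pipeline used to refine generated samples,
    merging LoRA weights when given and using the same scheduler/offload setup
    as the main pipeline."""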
|
controlnet = ControlNetModel2.from_pretrained( |
|
"lllyasviel/control_v11f1e_sd15_tile", torch_dtype=torch.float16 |
|
) |
|
if ( |
|
base_model_path == "runwayml/stable-diffusion-v1-5" |
|
or base_model_path == "stabilityai/stable-diffusion-2-inpainting" |
|
): |
|
print("base_model_path", base_model_path) |
|
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained( |
|
"runwayml/stable-diffusion-v1-5", |
|
controlnet=controlnet, |
|
torch_dtype=torch.float16, |
|
safety_checker=None, |
|
) |
|
else: |
|
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained( |
|
base_model_path, |
|
controlnet=controlnet, |
|
torch_dtype=torch.float16, |
|
safety_checker=None, |
|
) |
|
if lora_model_path is not None: |
|
pipe = load_lora_weights( |
|
pipe, [lora_model_path], lora_weight, "cpu", torch.float32 |
|
) |
|
|
|
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) |
|
|
|
pipe.enable_xformers_memory_efficient_attention() |
|
|
|
pipe.enable_model_cpu_offload() |
|
return pipe |
|
|
|
|
|
def show_anns(anns): |
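    """Render SAM annotations as a random-color overlay plus a label map.

    Returns ``(full_img, res)``: a PIL image with one random color per mask and an
    array whose first two channels encode the per-pixel mask index (low/high byte).
    """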
|
    if len(anns) == 0:
        return None, None
|
sorted_anns = sorted(anns, key=(lambda x: x["area"]), reverse=True) |
|
full_img = None |
|
|
|
|
|
for i in range(len(sorted_anns)): |
|
        ann = sorted_anns[i]  # iterate masks from largest to smallest area
|
m = ann["segmentation"] |
|
if full_img is None: |
|
full_img = np.zeros((m.shape[0], m.shape[1], 3)) |
|
map = np.zeros((m.shape[0], m.shape[1]), dtype=np.uint16) |
|
map[m != 0] = i + 1 |
|
color_mask = np.random.random((1, 3)).tolist()[0] |
|
full_img[m != 0] = color_mask |
|
full_img = full_img * 255 |
|
|
|
res = np.zeros((map.shape[0], map.shape[1], 3)) |
|
res[:, :, 0] = map % 256 |
|
res[:, :, 1] = map // 256 |
|
    res = res.astype(np.float32)
|
full_img = Image.fromarray(np.uint8(full_img)) |
|
return full_img, res |
|
|
|
|
|
class EditAnythingLoraModel: |
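    """Wraps SAM segmentation, BLIP-2 captioning, and ControlNet/LoRA Stable
    Diffusion pipelines behind the Gradio editing interface."""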
|
def __init__( |
|
self, |
|
base_model_path="../chilloutmix_NiPrunedFp32Fix", |
|
lora_model_path="../40806/mix4", |
|
use_blip=True, |
|
blip_processor=None, |
|
blip_model=None, |
|
sam_generator=None, |
|
controlmodel_name="LAION Pretrained(v0-4)-SD15", |
|
|
|
extra_inpaint=True, |
|
tile_model=None, |
|
lora_weight=1.0, |
|
alpha_mixing=None, |
|
mask_predictor=None, |
|
): |
|
self.device = device |
|
self.use_blip = use_blip |
|
|
|
|
|
self.default_controlnet_path = config_dict[controlmodel_name] |
|
self.base_model_path = base_model_path |
|
self.lora_model_path = lora_model_path |
|
self.defalut_enable_all_generate = False |
|
self.extra_inpaint = extra_inpaint |
|
self.last_ref_infer = False |
|
self.pipe = obtain_generation_model( |
|
base_model_path, |
|
lora_model_path, |
|
self.default_controlnet_path, |
|
generation_only=False, |
|
extra_inpaint=extra_inpaint, |
|
lora_weight=lora_weight, |
|
) |
|
|
|
|
|
self.sam_generator, self.mask_predictor = init_sam_model( |
|
sam_generator, mask_predictor |
|
) |
|
|
|
if use_blip: |
|
if blip_processor is not None: |
|
self.blip_processor = blip_processor |
|
else: |
|
self.blip_processor = init_blip_processor() |
|
|
|
if blip_model is not None: |
|
self.blip_model = blip_model |
|
else: |
|
self.blip_model = init_blip_model() |
|
|
|
|
|
if tile_model is not None: |
|
self.tile_pipe = tile_model |
|
else: |
|
self.tile_pipe = obtain_tile_model( |
|
base_model_path, lora_model_path, lora_weight=lora_weight |
|
) |
|
|
|
def get_blip2_text(self, image): |
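        """Caption ``image`` with BLIP-2, generating up to 50 new tokens."""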
|
inputs = self.blip_processor(image, return_tensors="pt").to( |
|
self.device, torch.float16 |
|
) |
|
generated_ids = self.blip_model.generate(**inputs, max_new_tokens=50) |
|
generated_text = self.blip_processor.batch_decode( |
|
generated_ids, skip_special_tokens=True |
|
)[0].strip() |
|
return generated_text |
|
|
|
def get_sam_control(self, image): |
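        """Run SAM automatic mask generation and return the colored overlay and label map."""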
|
masks = self.sam_generator.generate(image) |
|
full_img, res = show_anns(masks) |
|
return full_img, res |
|
|
|
def get_click_mask(self, image, clicked_points): |
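        """Predict a single SAM mask from the accumulated (x, y, label) click prompts."""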
|
self.mask_predictor.set_image(image) |
|
|
|
points, labels = zip(*[(point[:2], point[2]) |
|
for point in clicked_points]) |
|
|
|
|
|
input_point = np.array(points) |
|
input_label = np.array(labels) |
|
|
|
masks, _, _ = self.mask_predictor.predict( |
|
point_coords=input_point, |
|
point_labels=input_label, |
|
multimask_output=False, |
|
) |
|
|
|
return masks |
|
|
|
@torch.inference_mode() |
|
def process_image_click( |
|
self, |
|
original_image: gr.Image, |
|
point_prompt: gr.Radio, |
|
clicked_points: gr.State, |
|
image_resolution, |
|
evt: gr.SelectData, |
|
): |
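        """Handle a click on the Gradio image: record the point, update the SAM click
        mask, and return the point/mask overlay, the click list, and the mask image."""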
|
|
|
clicked_coords = evt.index |
|
x, y = clicked_coords |
|
label = point_prompt |
|
lab = 1 if label == "Foreground Point" else 0 |
|
clicked_points.append((x, y, lab)) |
|
|
|
input_image = np.array(original_image, dtype=np.uint8) |
|
H, W, C = input_image.shape |
|
input_image = HWC3(input_image) |
|
img = resize_image(input_image, image_resolution) |
|
|
|
|
|
resized_points = resize_points( |
|
clicked_points, input_image.shape, image_resolution |
|
) |
|
mask_click_np = self.get_click_mask(img, resized_points) |
|
|
|
|
|
mask_click_np = np.transpose(mask_click_np, (1, 2, 0)) * 255.0 |
|
|
|
mask_image = HWC3(mask_click_np.astype(np.uint8)) |
|
mask_image = cv2.resize(mask_image, (W, H), interpolation=cv2.INTER_LINEAR) |
|
|
|
|
|
|
|
edited_image = input_image |
|
for x, y, lab in clicked_points: |
|
|
|
color = (255, 0, 0) if lab == 1 else (0, 0, 255) |
|
|
|
|
|
edited_image = cv2.circle(edited_image, (x, y), 20, color, -1) |
|
|
|
|
|
opacity_mask = 0.75 |
|
opacity_edited = 1.0 |
|
|
|
|
|
overlay_image = cv2.addWeighted( |
|
edited_image, |
|
opacity_edited, |
|
(mask_image * |
|
np.array([0 / 255, 255 / 255, 0 / 255])).astype(np.uint8), |
|
opacity_mask, |
|
0, |
|
) |
|
|
|
return ( |
|
Image.fromarray(overlay_image), |
|
clicked_points, |
|
Image.fromarray(mask_image), |
|
) |
|
|
|
@torch.inference_mode() |
|
def process( |
|
self, |
|
source_image, |
|
enable_all_generate, |
|
mask_image, |
|
control_scale, |
|
enable_auto_prompt, |
|
a_prompt, |
|
n_prompt, |
|
num_samples, |
|
image_resolution, |
|
detect_resolution, |
|
ddim_steps, |
|
guess_mode, |
|
scale, |
|
seed, |
|
eta, |
|
enable_tile=True, |
|
refine_alignment_ratio=None, |
|
refine_image_resolution=None, |
|
alpha_weight=0.5, |
|
use_scale_map=False, |
|
condition_model=None, |
|
ref_image=None, |
|
attention_auto_machine_weight=1.0, |
|
gn_auto_machine_weight=1.0, |
|
style_fidelity=0.5, |
|
reference_attn=True, |
|
reference_adain=True, |
|
ref_prompt=None, |
|
ref_sam_scale=None, |
|
ref_inpaint_scale=None, |
|
ref_auto_prompt=False, |
|
ref_textinv=True, |
|
ref_textinv_path=None, |
|
): |
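        """Main editing entry point.

        Builds the SAM segmentation condition (plus optional inpaint and reference
        conditions), runs the ControlNet pipeline on the masked region, and, when
        ``enable_tile`` is set, refines each sample with the tile pipeline.
        Returns (refined samples, raw samples, [segmentation, mask], positive prompt).
        """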
|
|
|
if condition_model is None or condition_model == "EditAnything": |
|
this_controlnet_path = self.default_controlnet_path |
|
else: |
|
this_controlnet_path = condition_model |
|
input_image = ( |
|
source_image["image"] |
|
if isinstance(source_image, dict) |
|
else np.array(source_image, dtype=np.uint8) |
|
) |
|
if mask_image is None: |
|
if enable_all_generate != self.defalut_enable_all_generate: |
|
self.pipe = obtain_generation_model( |
|
self.base_model_path, |
|
self.lora_model_path, |
|
this_controlnet_path, |
|
enable_all_generate, |
|
self.extra_inpaint, |
|
) |
|
self.defalut_enable_all_generate = enable_all_generate |
|
if enable_all_generate: |
|
print( |
|
"source_image", |
|
source_image["mask"].shape, |
|
input_image.shape, |
|
) |
|
mask_image = ( |
|
np.ones((input_image.shape[0], |
|
input_image.shape[1], 3)) * 255 |
|
) |
|
else: |
|
mask_image = source_image["mask"] |
|
else: |
|
mask_image = np.array(mask_image, dtype=np.uint8) |
|
if self.default_controlnet_path != this_controlnet_path: |
|
print( |
|
"To Use:", |
|
this_controlnet_path, |
|
"Current:", |
|
self.default_controlnet_path, |
|
) |
|
print("Change condition model to:", this_controlnet_path) |
|
self.pipe = obtain_generation_model( |
|
self.base_model_path, |
|
self.lora_model_path, |
|
this_controlnet_path, |
|
enable_all_generate, |
|
self.extra_inpaint, |
|
) |
|
self.default_controlnet_path = this_controlnet_path |
|
torch.cuda.empty_cache() |
|
if self.last_ref_infer: |
|
print("Redefine the model to overwrite the ref mode") |
|
self.pipe = obtain_generation_model( |
|
self.base_model_path, |
|
self.lora_model_path, |
|
this_controlnet_path, |
|
enable_all_generate, |
|
self.extra_inpaint, |
|
) |
|
self.last_ref_infer = False |
|
|
|
if ref_image is not None: |
|
ref_mask = ref_image["mask"] |
|
ref_image = ref_image["image"] |
|
if ref_auto_prompt or ref_textinv: |
|
bbox = get_bounding_box( |
|
np.array(ref_mask) / 255 |
|
) |
|
cropped_ref_mask = ref_mask.crop( |
|
(bbox[0], bbox[1], bbox[2], bbox[3])) |
|
cropped_ref_image = ref_image.crop( |
|
(bbox[0], bbox[1], bbox[2], bbox[3])) |
|
|
|
cropped_ref_image = np.array(cropped_ref_image) * ( |
|
np.array(cropped_ref_mask)[:, :, :3] / 255.0 |
|
) |
|
cropped_ref_image = Image.fromarray( |
|
cropped_ref_image.astype("uint8")) |
|
|
|
if ref_auto_prompt: |
|
generated_prompt = self.get_blip2_text(cropped_ref_image) |
|
ref_prompt += generated_prompt |
|
a_prompt += generated_prompt |
|
print("Generated ref text:", ref_prompt) |
|
print("Generated input text:", a_prompt) |
|
self.last_ref_infer = True |
|
|
|
|
|
if ref_textinv: |
|
try: |
|
self.pipe.load_textual_inversion(ref_textinv_path) |
|
print("Load textinv embedding from:", ref_textinv_path) |
|
                except Exception:
                    print("No textual-inversion embeddings found.")
                    ref_data_path = "./utils/tmp/textinv/img"
                    if not os.path.exists(ref_data_path):
                        os.makedirs(ref_data_path)
                    cropped_ref_image.save(os.path.join(ref_data_path, 'ref.png'))
                    print("Ref image region is saved to:", ref_data_path)
                    print("Please finetune with run_texutal_inversion.sh in the utils folder to get the textual-inversion embeddings.")
|
|
|
else: |
|
ref_mask = None |
|
|
|
with torch.no_grad(): |
|
if self.use_blip and enable_auto_prompt: |
|
print("Generating text:") |
|
blip2_prompt = self.get_blip2_text(input_image) |
|
print("Generated text:", blip2_prompt) |
|
if len(a_prompt) > 0: |
|
a_prompt = blip2_prompt + "," + a_prompt |
|
else: |
|
a_prompt = blip2_prompt |
|
|
|
input_image = HWC3(input_image) |
|
|
|
img = resize_image(input_image, image_resolution) |
|
H, W, C = img.shape |
|
|
|
print("Generating SAM seg:") |
|
|
|
full_segmask, detected_map = self.get_sam_control( |
|
resize_image(input_image, detect_resolution) |
|
) |
|
|
|
detected_map = HWC3(detected_map.astype(np.uint8)) |
|
detected_map = cv2.resize( |
|
detected_map, (W, H), interpolation=cv2.INTER_LINEAR |
|
) |
|
|
|
control = torch.from_numpy(detected_map.copy()).float().cuda() |
|
control = torch.stack([control for _ in range(num_samples)], dim=0) |
|
control = einops.rearrange(control, "b h w c -> b c h w").clone() |
|
|
|
mask_imag_ori = HWC3(mask_image.astype(np.uint8)) |
|
mask_image_tmp = cv2.resize( |
|
mask_imag_ori, (W, H), interpolation=cv2.INTER_LINEAR |
|
) |
|
mask_image = Image.fromarray(mask_image_tmp) |
|
|
|
if seed == -1: |
|
seed = random.randint(0, 65535) |
|
seed_everything(seed) |
|
generator = torch.manual_seed(seed) |
|
postive_prompt = a_prompt |
|
negative_prompt = n_prompt |
|
prompt_embeds, negative_prompt_embeds = get_pipeline_embeds( |
|
self.pipe, postive_prompt, negative_prompt, "cuda" |
|
) |
|
prompt_embeds = torch.cat([prompt_embeds] * num_samples, dim=0) |
|
negative_prompt_embeds = torch.cat( |
|
[negative_prompt_embeds] * num_samples, dim=0 |
|
) |
|
|
|
if enable_all_generate and self.extra_inpaint: |
|
self.pipe.safety_checker = lambda images, clip_input: ( |
|
images, False) |
|
if ref_image is not None: |
|
print("Not support yet.") |
|
return |
|
x_samples = self.pipe( |
|
prompt_embeds=prompt_embeds, |
|
negative_prompt_embeds=negative_prompt_embeds, |
|
num_images_per_prompt=num_samples, |
|
num_inference_steps=ddim_steps, |
|
generator=generator, |
|
height=H, |
|
width=W, |
|
image=[control.type(torch.float16)], |
|
controlnet_conditioning_scale=[float(control_scale)], |
|
guidance_scale=scale, |
|
guess_mode=guess_mode, |
|
).images |
|
else: |
|
multi_condition_image = [] |
|
multi_condition_scale = [] |
|
multi_condition_image.append(control.type(torch.float16)) |
|
multi_condition_scale.append(float(control_scale)) |
|
ref_multi_condition_scale = [] |
|
if ref_image is not None: |
|
ref_multi_condition_scale.append(float(ref_sam_scale)) |
|
if self.extra_inpaint: |
|
inpaint_image = make_inpaint_condition(img, mask_image_tmp) |
|
multi_condition_image.append( |
|
inpaint_image.type(torch.float16)) |
|
multi_condition_scale.append(1.0) |
|
if ref_image is not None: |
|
ref_multi_condition_scale.append( |
|
float(ref_inpaint_scale)) |
|
if use_scale_map: |
|
scale_map_tmp = source_image["mask"] |
|
tmp = HWC3(scale_map_tmp.astype(np.uint8)) |
|
scale_map_tmp = cv2.resize( |
|
tmp, (W, H), interpolation=cv2.INTER_LINEAR) |
|
scale_map_tmp = Image.fromarray(scale_map_tmp) |
|
controlnet_conditioning_scale_map = 1.0 - \ |
|
prepare_mask_image(scale_map_tmp).float() |
|
print('scale map:', controlnet_conditioning_scale_map.size()) |
|
else: |
|
controlnet_conditioning_scale_map = None |
|
|
|
if isinstance(self.pipe, StableDiffusionControlNetInpaintMixingPipeline): |
|
x_samples = self.pipe( |
|
image=img, |
|
mask_image=mask_image, |
|
prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, |
|
num_images_per_prompt=num_samples, |
|
num_inference_steps=ddim_steps, |
|
generator=generator, |
|
controlnet_conditioning_image=multi_condition_image, |
|
height=H, |
|
width=W, |
|
controlnet_conditioning_scale=multi_condition_scale, |
|
guidance_scale=scale, |
|
alpha_weight=alpha_weight, |
|
controlnet_conditioning_scale_map=controlnet_conditioning_scale_map |
|
).images |
|
else: |
|
x_samples = self.pipe( |
|
image=img, |
|
mask_image=mask_image, |
|
prompt_embeds=prompt_embeds, |
|
negative_prompt_embeds=negative_prompt_embeds, |
|
num_images_per_prompt=num_samples, |
|
num_inference_steps=ddim_steps, |
|
generator=generator, |
|
controlnet_conditioning_image=multi_condition_image, |
|
height=H, |
|
width=W, |
|
controlnet_conditioning_scale=multi_condition_scale, |
|
guidance_scale=scale, |
|
ref_image=ref_image, |
|
ref_mask=ref_mask, |
|
ref_prompt=ref_prompt, |
|
attention_auto_machine_weight=attention_auto_machine_weight, |
|
gn_auto_machine_weight=gn_auto_machine_weight, |
|
style_fidelity=style_fidelity, |
|
reference_attn=reference_attn, |
|
reference_adain=reference_adain, |
|
ref_controlnet_conditioning_scale=ref_multi_condition_scale, |
|
guess_mode=guess_mode, |
|
).images |
|
results = [x_samples[i] for i in range(num_samples)] |
|
|
|
results_tile = [] |
|
if enable_tile: |
|
prompt_embeds, negative_prompt_embeds = get_pipeline_embeds( |
|
self.tile_pipe, postive_prompt, negative_prompt, "cuda" |
|
) |
|
for i in range(num_samples): |
|
img_tile = PIL.Image.fromarray( |
|
resize_image( |
|
np.array(x_samples[i]), refine_image_resolution) |
|
) |
|
if i == 0: |
|
mask_image_tile = cv2.resize( |
|
mask_imag_ori, |
|
(img_tile.size[0], img_tile.size[1]), |
|
interpolation=cv2.INTER_LINEAR, |
|
) |
|
mask_image_tile = Image.fromarray(mask_image_tile) |
|
if isinstance(self.pipe, StableDiffusionControlNetInpaintMixingPipeline): |
|
x_samples_tile = self.tile_pipe( |
|
image=img_tile, |
|
mask_image=mask_image_tile, |
|
prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, |
|
num_images_per_prompt=1, |
|
num_inference_steps=ddim_steps, |
|
generator=generator, |
|
controlnet_conditioning_image=img_tile, |
|
height=img_tile.size[1], |
|
width=img_tile.size[0], |
|
controlnet_conditioning_scale=1.0, |
|
alignment_ratio=refine_alignment_ratio, |
|
guidance_scale=scale, |
|
alpha_weight=alpha_weight, |
|
controlnet_conditioning_scale_map=controlnet_conditioning_scale_map |
|
).images |
|
else: |
|
x_samples_tile = self.tile_pipe( |
|
image=img_tile, |
|
mask_image=mask_image_tile, |
|
prompt_embeds=prompt_embeds, |
|
negative_prompt_embeds=negative_prompt_embeds, |
|
num_images_per_prompt=1, |
|
num_inference_steps=ddim_steps, |
|
generator=generator, |
|
controlnet_conditioning_image=img_tile, |
|
height=img_tile.size[1], |
|
width=img_tile.size[0], |
|
controlnet_conditioning_scale=1.0, |
|
alignment_ratio=refine_alignment_ratio, |
|
guidance_scale=scale, |
|
guess_mode=guess_mode, |
|
).images |
|
results_tile += x_samples_tile |
|
|
|
return results_tile, results, [full_segmask, mask_image], postive_prompt |
|
|
|
def download_image(url): |
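    """Download an image from ``url`` and return it as an RGB PIL image."""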
|
response = requests.get(url) |
|
return Image.open(BytesIO(response.content)).convert("RGB") |
|
|