import os
import argparse

from tqdm.auto import tqdm
from packaging import version

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torchvision import transforms
from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    DDPMScheduler,
    StableDiffusionControlNetPipeline,
    UNet2DConditionModel,
    UniPCMultistepScheduler,
    PNDMScheduler,
    AmusedInpaintPipeline,
    AmusedScheduler,
    VQModel,
    UVit2DModel,
)
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils import load_image
from transformers import AutoTokenizer, CLIPFeatureExtractor, PretrainedConfig
from PIL import Image

from utils.mclip import *  # provides MultilingualCLIP (the MUSE pipeline's text encoder)

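# This script edits a face image with M3Face in three stages:
#   1. Freeze all networks and optimize the prompt embedding so the ControlNet-conditioned
#      diffusion model reconstructs the input image.
#   2. Freeze the embedding and finetune the selected UNet up-blocks with the same objective.
#   3. Optionally rewrite the mask/landmark condition with the MUSE inpainting model, then
#      generate edited images by blending the original and optimized embeddings for every
#      requested alpha and inference-step count.
#
# Example invocation (script name, prompt, and paths are illustrative, not fixed by this file):
#   python edit.py --prompt "She has blond hair." --condition mask \
#       --image_path face.png --condition_path face_mask.png --edit_condition --seed 42

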
def parse_args():
    parser = argparse.ArgumentParser(description="Edit images with M3Face.")
    parser.add_argument(
        "--prompt",
        type=str,
        default="This attractive woman has narrow eyes, rosy cheeks, and wears heavy makeup.",
        help="The input text prompt for image generation."
    )
    parser.add_argument(
        "--condition",
        type=str,
        default="mask",
        choices=["mask", "landmark"],
        help="Use a segmentation mask or facial landmarks for image generation."
    )
    parser.add_argument(
        "--image_path",
        type=str,
        default=None,
        help="Path to the input image."
    )
    parser.add_argument(
        "--condition_path",
        type=str,
        default=None,
        help="Path to the original mask/landmark image."
    )
    parser.add_argument(
        "--edit_condition_path",
        type=str,
        default=None,
        help="Path to the target mask/landmark image."
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default='output/',
        help="The output directory where the results will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible generation.")
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
    )
    parser.add_argument("--edit_condition", action="store_true", help="Edit the condition image with the MUSE model before generation.")
    parser.add_argument("--load_unet_from_local", action="store_true", help="Load the UNet weights from --unet_local_path.")
    parser.add_argument("--save_unet", action="store_true", help="Save the finetuned UNet to the output directory.")
    parser.add_argument("--unet_local_path", type=str, default=None, help="Path to a locally saved UNet.")
    parser.add_argument("--load_finetune_from_local", action="store_true", help="Load previously optimized embeddings from --finetune_path.")
    parser.add_argument("--finetune_path", type=str, default=None, help="Directory containing saved orig_emb.pt and emb.pt.")
    parser.add_argument("--use_english", action="store_true", help="Use the English models.")
    parser.add_argument("--embedding_optimize_it", type=int, default=500, help="Number of embedding optimization steps.")
    parser.add_argument("--model_finetune_it", type=int, default=1000, help="Number of UNet finetuning steps.")
    parser.add_argument("--alpha", nargs="+", type=float, default=[0.8, 0.9, 1, 1.1], help="Interpolation weights between the original and optimized embeddings.")
    parser.add_argument("--num_inference_steps", nargs="+", type=int, default=[20, 40, 50], help="Denoising step counts to use at generation time.")
parser.add_argument("--unet_layer", type=str, default="2and3", |
|
help="Which UNet layers in the SD to finetune.") |
|
|
|
args = parser.parse_args() |
|
|
|
return args |
|
|
|
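# Load the m3face/FaceConditioning MUSE (aMUSEd) inpainting pipeline that rewrites the
# mask/landmark condition image according to the text prompt. The revision picks the
# segmentation- or landmark-trained checkpoint; MultilingualCLIP comes from utils.mclip.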
def get_muse(args):
    muse_model_name = 'm3face/FaceConditioning'
    if args.condition == 'mask':
        muse_revision = 'segmentation'
    elif args.condition == 'landmark':
        muse_revision = 'landmark'
    scheduler = AmusedScheduler.from_pretrained(muse_model_name, revision=muse_revision, subfolder='scheduler')
    vqvae = VQModel.from_pretrained(muse_model_name, revision=muse_revision, subfolder='vqvae')
    uvit2 = UVit2DModel.from_pretrained(muse_model_name, revision=muse_revision, subfolder='transformer')
    text_encoder = MultilingualCLIP.from_pretrained(muse_model_name, revision=muse_revision, subfolder='text_encoder')
    tokenizer = AutoTokenizer.from_pretrained(muse_model_name, revision=muse_revision, subfolder='tokenizer')

    pipeline = AmusedInpaintPipeline(
        vqvae=vqvae,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        transformer=uvit2,
        scheduler=scheduler
    ).to("cuda")

    return pipeline


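# Resolve the text-encoder class that matches the base model: CLIPTextModel for the English
# Stable Diffusion checkpoint, RobertaSeriesModelWithTransformation for the multilingual
# AltDiffusion checkpoint.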
def import_model_class_from_model_name(sd_model_name):
    text_encoder_config = PretrainedConfig.from_pretrained(
        sd_model_name,
        subfolder="text_encoder",
    )
    model_class = text_encoder_config.architectures[0]

    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    elif model_class == "RobertaSeriesModelWithTransformation":
        from diffusers.pipelines.deprecated.alt_diffusion import RobertaSeriesModelWithTransformation

        return RobertaSeriesModelWithTransformation
    else:
        raise ValueError(f"{model_class} is not supported.")


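# Resize and center-crop both images to 512x512, normalize the face image to [-1, 1]
# (the condition image stays in [0, 1] for ControlNet), and tokenize the prompt padded
# to the tokenizer's maximum length.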
def preprocess(image, condition, prompt, tokenizer):
    image_transforms = transforms.Compose(
        [
            transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(512),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )
    condition_transforms = transforms.Compose(
        [
            transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(512),
            transforms.ToTensor(),
        ]
    )
    image = image_transforms(image)
    condition = condition_transforms(condition)
    inputs = tokenizer(
        [prompt], max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
    )

    return image, condition, inputs.input_ids, inputs.attention_mask


def main(args):
    # Make sure the output directory exists before any embeddings or images are written to it.
    os.makedirs(args.output_dir, exist_ok=True)

    if args.use_english:
        sd_model_name = 'runwayml/stable-diffusion-v1-5'
        controlnet_model_name = 'm3face/FaceControlNet'
        if args.condition == 'mask':
            controlnet_revision = 'segmentation-english'
        elif args.condition == 'landmark':
            controlnet_revision = 'landmark-english'
    else:
        sd_model_name = 'BAAI/AltDiffusion-m18'
        controlnet_model_name = 'm3face/FaceControlNet'
        if args.condition == 'mask':
            controlnet_revision = 'segmentation-mlin'
        elif args.condition == 'landmark':
            controlnet_revision = 'landmark-mlin'

    vae = AutoencoderKL.from_pretrained(sd_model_name, subfolder="vae")
    tokenizer = AutoTokenizer.from_pretrained(sd_model_name, subfolder="tokenizer", use_fast=False)
    text_encoder_cls = import_model_class_from_model_name(sd_model_name)
    text_encoder = text_encoder_cls.from_pretrained(sd_model_name, subfolder="text_encoder")
    noise_scheduler = DDPMScheduler.from_pretrained(sd_model_name, subfolder="scheduler")

    if args.load_unet_from_local:
        unet = UNet2DConditionModel.from_pretrained(args.unet_local_path)
    else:
        unet = UNet2DConditionModel.from_pretrained(sd_model_name, subfolder="unet")

    controlnet = ControlNetModel.from_pretrained(controlnet_model_name, revision=controlnet_revision)

    if args.edit_condition:
        muse = get_muse(args)

    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    controlnet.requires_grad_(False)
    unet.requires_grad_(False)
    vae.eval()
    text_encoder.eval()
    controlnet.eval()
    unet.eval()

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                print(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            unet.enable_xformers_memory_efficient_attention()
            controlnet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    params = []
    for name, param in unet.named_parameters():
        if name.startswith('up_blocks'):
            params.append(param)

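    # `params` holds the up_blocks parameters in registration order; the slices below pick out
    # subsets of them per the --unet_layer option. The hard-coded boundaries (38, 154, 270) are
    # assumed to follow the Stable Diffusion v1.5 UNet parameter ordering.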
    if args.unet_layer == 'only1':
        params_to_optimize = [
            {'params': params[38:154]},
        ]
    elif args.unet_layer == 'only2':
        params_to_optimize = [
            {'params': params[154:270]},
        ]
    elif args.unet_layer == 'only3':
        params_to_optimize = [
            {'params': params[270:]},
        ]
    elif args.unet_layer == '1and2':
        params_to_optimize = [
            {'params': params[38:270]},
        ]
    elif args.unet_layer == '2and3':
        params_to_optimize = [
            {'params': params[154:]},
        ]
    elif args.unet_layer == 'all':
        params_to_optimize = [
            {'params': params},
        ]

    image = Image.open(args.image_path).convert('RGB')
    condition = Image.open(args.condition_path).convert('RGB')
    image, condition, input_ids, attention_mask = preprocess(image, condition, args.prompt, tokenizer)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    vae.to(device, dtype=torch.float32)
    unet.to(device, dtype=torch.float32)
    text_encoder.to(device, dtype=torch.float32)
    controlnet.to(device)
    image = image.to(device).unsqueeze(0)
    condition = condition.to(device).unsqueeze(0)
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)

    # The input image's VAE latent is needed by both optimization stages below.
    init_latent = vae.encode(image.to(dtype=torch.float32)).latent_dist.sample()
    init_latent = init_latent * vae.config.scaling_factor

    if args.load_finetune_from_local:
        print('Loading embeddings from local ...')
        orig_emb = torch.load(os.path.join(args.finetune_path, 'orig_emb.pt'))
        emb = torch.load(os.path.join(args.finetune_path, 'emb.pt'))
    else:
        if not args.use_english:
            orig_emb = text_encoder(input_ids, attention_mask=attention_mask)[0]
        else:
            orig_emb = text_encoder(input_ids)[0]
        emb = orig_emb.clone()
        torch.save(orig_emb, os.path.join(args.output_dir, 'orig_emb.pt'))
        torch.save(emb, os.path.join(args.output_dir, 'emb.pt'))

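    # Stage 1: keep every network frozen and optimize only the prompt embedding `emb` so the
    # ControlNet-conditioned UNet reconstructs the input image. Each step adds noise to the
    # image latent at a random timestep and minimizes the MSE between the UNet prediction and
    # the diffusion target (the noise itself, or the velocity for v-prediction schedulers).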
    print('1. Optimize the embedding')
    unet.eval()
    emb.requires_grad = True
    lr = 0.001
    it = args.embedding_optimize_it
    opt = torch.optim.Adam([emb], lr=lr)
    history = []

    pbar = tqdm(
        range(it),
        initial=0,
        desc="Optimize Steps",
    )
    global_step = 0

    for i in pbar:
        opt.zero_grad()

        noise = torch.randn_like(init_latent)
        bsz = init_latent.shape[0]
        t_enc = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=init_latent.device)
        t_enc = t_enc.long()
        z = noise_scheduler.add_noise(init_latent, noise, t_enc)

        controlnet_image = condition.to(dtype=torch.float32)

        down_block_res_samples, mid_block_res_sample = controlnet(
            z,
            t_enc,
            encoder_hidden_states=emb,
            controlnet_cond=controlnet_image,
            return_dict=False,
        )

        pred_noise = unet(
            z,
            t_enc,
            encoder_hidden_states=emb,
            down_block_additional_residuals=[
                sample.to(dtype=torch.float32) for sample in down_block_res_samples
            ],
            mid_block_additional_residual=mid_block_res_sample.to(dtype=torch.float32),
        ).sample

        if noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif noise_scheduler.config.prediction_type == "v_prediction":
            target = noise_scheduler.get_velocity(init_latent, noise, t_enc)
        else:
            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
        loss = F.mse_loss(pred_noise.float(), target.float(), reduction="mean")

        loss.backward()
        global_step += 1
        pbar.set_postfix({"loss": loss.item()})
        history.append(loss.item())
        opt.step()
        opt.zero_grad()

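    # Stage 2: freeze the optimized embedding and finetune only the selected UNet up-block
    # parameters with the same denoising objective, so the model reconstructs the input image
    # more faithfully around this embedding.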
    print('2. Finetune the model')
    emb.requires_grad = False
    unet.requires_grad_(True)
    unet.train()

    lr = 5e-5
    it = args.model_finetune_it
    opt = torch.optim.Adam(params_to_optimize, lr=lr)
    history = []

    pbar = tqdm(
        range(it),
        initial=0,
        desc="Finetune Steps",
    )
    global_step = 0

    for i in pbar:
        opt.zero_grad()

        noise = torch.randn_like(init_latent)
        bsz = init_latent.shape[0]
        t_enc = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=init_latent.device)
        t_enc = t_enc.long()
        z = noise_scheduler.add_noise(init_latent, noise, t_enc)

        controlnet_image = condition.to(dtype=torch.float32)

        down_block_res_samples, mid_block_res_sample = controlnet(
            z,
            t_enc,
            encoder_hidden_states=emb,
            controlnet_cond=controlnet_image,
            return_dict=False,
        )

        pred_noise = unet(
            z,
            t_enc,
            encoder_hidden_states=emb,
            down_block_additional_residuals=[
                sample.to(dtype=torch.float32) for sample in down_block_res_samples
            ],
            mid_block_additional_residual=mid_block_res_sample.to(dtype=torch.float32),
        ).sample

        if noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif noise_scheduler.config.prediction_type == "v_prediction":
            target = noise_scheduler.get_velocity(init_latent, noise, t_enc)
        else:
            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
        loss = F.mse_loss(pred_noise.float(), target.float(), reduction="mean")

        loss.backward()
        global_step += 1
        pbar.set_postfix({"loss": loss.item()})
        history.append(loss.item())
        opt.step()
        opt.zero_grad()

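    # Stage 3: optionally let MUSE inpaint a new condition image from the prompt, then build a
    # ControlNet Stable Diffusion pipeline (UniPC scheduler) and generate a result for every
    # (num_inference_steps, alpha) combination requested on the command line.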
print("3. Generating images... ") |
|
|
|
unet.eval() |
|
controlnet.eval() |
|
|
|
if args.edit_condition_path is None: |
|
edit_condition = load_image(args.condition_path) |
|
else: |
|
edit_condition = load_image(args.edit_condition_path) |
|
if args.edit_condition: |
|
edit_mask = Image.new("L", (256, 256), 0) |
|
for i in range(256): |
|
for j in range(256): |
|
if 40 < i < 220 and 20 < j < 256: |
|
edit_mask.putpixel((i, j), 256) |
|
|
|
if args.condition == 'mask': |
|
condition = 'segmentation' |
|
elif args.condition == 'landmark': |
|
condition = 'landmark' |
|
edit_prompt = f"Generate face {condition} | " + args.prompt |
|
input_image = edit_condition.resize((256, 256)).convert("RGB") |
|
edit_condition = muse(edit_prompt, input_image, edit_mask, num_inference_steps=30).images[0].resize((512, 512)) |
|
edit_condition.save(f'{args.output_dir}/edited_condition.png') |
|
|
|
|
|
del muse |
|
torch.cuda.empty_cache() |
|
|
|
    if sd_model_name.startswith('BAAI'):
        scheduler = PNDMScheduler.from_pretrained(
            sd_model_name,
            subfolder='scheduler',
        )
        scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
        feature_extractor = CLIPFeatureExtractor.from_pretrained(
            sd_model_name,
            subfolder='feature_extractor',
        )
        pipeline = StableDiffusionControlNetPipeline(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            controlnet=controlnet,
            scheduler=scheduler,
            safety_checker=None,
            feature_extractor=feature_extractor
        )
    else:
        pipeline = StableDiffusionControlNetPipeline.from_pretrained(
            sd_model_name,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            controlnet=controlnet,
            safety_checker=None,
        )
    pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
    pipeline = pipeline.to(device)
    pipeline.set_progress_bar_config(disable=True)

    if args.enable_xformers_memory_efficient_attention:
        pipeline.enable_xformers_memory_efficient_attention()

    if args.seed is None:
        generator = None
    else:
        generator = torch.Generator(device=device).manual_seed(args.seed)

    with torch.autocast("cuda"):
        image = pipeline(
            image=edit_condition, prompt_embeds=emb, num_inference_steps=20, generator=generator
        ).images[0]
    image.save(f'{args.output_dir}/reconstruct.png')

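    # Blend the original prompt embedding with the optimized one: alpha = 0 keeps the optimized
    # (reconstruction) embedding, alpha = 1 uses the original text embedding, and alpha > 1
    # extrapolates past it to apply the edit more strongly.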
    for num_inference_steps in args.num_inference_steps:
        for alpha in args.alpha:
            new_emb = alpha * orig_emb + (1 - alpha) * emb

            with torch.autocast("cuda"):
                image = pipeline(
                    image=edit_condition, prompt_embeds=new_emb, num_inference_steps=num_inference_steps, generator=generator
                ).images[0]
            image.save(f'{args.output_dir}/image_{num_inference_steps}_{alpha}.png')

    if args.save_unet:
        print('Saving the unet model...')
        unet.save_pretrained(f'{args.output_dir}/unet')


if __name__ == '__main__':
    args = parse_args()
    main(args)