|
|
|
import torch |
|
import numpy as np |
|
from PIL import Image |
|
|
|
def resize_image(image, *, pixel_number=960 * 960, granularity_val=64) -> Image.Image:
    """Resize *image* to roughly ``pixel_number`` pixels on a size grid.

    The target width is derived from the source aspect ratio and the pixel
    budget, then the height from the budget and that width. Both are rounded
    down to a multiple of ``granularity_val``.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and
            ``resize((width, height))`` (a PIL image in this file).
        pixel_number: Approximate pixel budget of the result.
        granularity_val: Both output dimensions are multiples of this value.

    Returns:
        Whatever ``image.resize`` returns for the computed ``(width, height)``.
    """
    src_width, src_height = image.size
    ratio = src_width / src_height

    width = int((pixel_number * ratio) ** 0.5)
    width -= width % granularity_val
    # Extremely tall inputs round down to 0, which would make the division
    # below fail; keep at least one granularity step.
    width = max(width, granularity_val)

    height = int(pixel_number / width)
    height -= height % granularity_val
    # Extremely wide inputs round the height down to 0, which is not a valid
    # resize target; keep at least one granularity step here as well.
    height = max(height, granularity_val)

    return image.resize((width, height))
|
|
|
def get_masked_background_image(image, image_mask) -> tuple:
    """Return the image with low-mask pixels zeroed, plus the float mask.

    The mask is resized to the image's size. The image is converted to RGB
    and returned as a channels-first float32 array divided by 255. The mask
    is converted to mode "L" and returned as a float32 array divided by 255.
    Every channel of the image is set to 0 wherever the mask is below 0.5.
    The returned mask itself is not thresholded.

    Args:
        image: PIL-style image providing ``size`` and ``convert``.
        image_mask: PIL-style image providing ``resize`` and ``convert``.

    Returns:
        ``(masked_image, mask)`` as NumPy arrays.
    """
    # Match the mask's resolution to the image so they index pixel-for-pixel.
    resized_mask = image_mask.resize(image.size)

    rgb_pixels = np.array(image.convert("RGB"))
    masked_rgb = rgb_pixels.transpose(2, 0, 1).astype(np.float32) / 255.0

    gray_pixels = np.array(resized_mask.convert("L"))
    mask_values = gray_pixels.astype(np.float32) / 255.0

    # Blank out all channels at once where the mask is under the midpoint.
    below_midpoint = mask_values < 0.5
    masked_rgb[:, below_midpoint] = 0

    return masked_rgb, mask_values
|
|
|
def get_control_image_tensor(vae, image, mask) -> torch.Tensor:
    """Build a control tensor: VAE latents of the masked image plus a mask channel.

    The image is masked via ``get_masked_background_image``, rescaled with
    ``(x - 0.5) / 0.5``, encoded by ``vae``, and multiplied by
    ``vae.config.scaling_factor``. The mask is binarized at 0.5, resized to
    the latent spatial size with nearest-neighbour interpolation, and
    concatenated to the latents along ``dim=1``.

    Args:
        vae: Model exposing ``encode(...).latent_dist.sample()``, ``dtype`` and
            ``config.scaling_factor`` (presumably a diffusers autoencoder;
            confirm against the caller).
        image: Source image, passed to ``get_masked_background_image``.
        mask: Mask image, passed to ``get_masked_background_image``.

    Returns:
        Tensor holding the scaled latents followed by the resized mask on the
        channel dimension.
    """
    masked_image, image_mask = get_masked_background_image(image, mask)

    # Follow the VAE's own device instead of assuming "cuda:0"; that value
    # stays as the fallback for objects that do not expose ``device``.
    device = getattr(vae, "device", "cuda:0")

    masked_image_tensor = torch.from_numpy(masked_image)
    masked_image_tensor = (masked_image_tensor - 0.5) / 0.5
    masked_image_tensor = masked_image_tensor.unsqueeze(0).to(device=device)

    control_latents = vae.encode(
        masked_image_tensor[:, :3, :, :].to(vae.dtype)
    ).latent_dist.sample()
    control_latents = control_latents * vae.config.scaling_factor

    # Place the mask wherever the latents ended up so the concatenation below
    # never mixes devices.
    mask_tensor = torch.tensor(image_mask, dtype=torch.float32)[None, None, ...]
    mask_tensor = mask_tensor.to(device=control_latents.device)
    mask_tensor = torch.where(mask_tensor > 0.5, 1.0, 0)
    mask_resized = torch.nn.functional.interpolate(
        mask_tensor,
        size=(control_latents.shape[2], control_latents.shape[3]),
        mode="nearest",
    )

    control_tensor = torch.cat([control_latents, mask_resized], dim=1)
    return control_tensor
|
|
|
# Lazily-built segmentation pipeline, created on first use and then reused so
# the model is not reloaded on every call.
_RMBG_PIPELINE = None


def remove_bg_from_image(image) -> Image.Image:
    """Run the ``briaai/RMBG-1.4`` segmentation pipeline and return its mask.

    Args:
        image: Input forwarded unchanged to the pipeline.

    Returns:
        The pipeline's output for ``return_mask=True`` (per the model card
        this is a PIL mask; confirm against the installed model revision).
    """
    global _RMBG_PIPELINE
    if _RMBG_PIPELINE is None:
        # Imported here so the module loads without transformers installed
        # unless background removal is actually used.
        from transformers import pipeline

        # NOTE(review): trust_remote_code=True executes code shipped with the
        # model repository; pin a trusted revision if that is a concern.
        _RMBG_PIPELINE = pipeline(
            "image-segmentation", model="briaai/RMBG-1.4", trust_remote_code=True
        )
    return _RMBG_PIPELINE(image, return_mask=True)
|
|
|
def paste_fg_over_image(gen_image: Image.Image, orig_image: Image.Image, fg_mask: Image.Image) -> Image.Image:
    """Paste the original image onto the generated one through a mask.

    The mask is converted to mode "L" and resized to the original image's
    size with nearest-neighbour resampling. The original image is then pasted
    at the top-left corner of an RGBA copy of the generated image, using the
    mask as the paste mask.

    Args:
        gen_image: Generated image that receives the paste.
        orig_image: Image supplying the pasted pixels.
        fg_mask: Mask selecting which pixels of ``orig_image`` are pasted.

    Returns:
        The composited result converted to RGB.
    """
    paste_mask = fg_mask.convert("L").resize(orig_image.size, Image.NEAREST)

    canvas = gen_image.convert("RGBA")
    foreground = orig_image.convert("RGBA")
    canvas.paste(foreground, (0, 0), paste_mask)

    return canvas.convert("RGB")