import torch
import torch.nn.functional as F
import traceback

import gradio as gr
import spaces
from diffusers import DDIMScheduler, DiffusionPipeline
from PIL import Image, ImageFilter
from torchvision.transforms.functional import gaussian_blur, to_pil_image, to_tensor

def preprocess_image(input_image, device):
    # Convert a PIL image to a (1, 3, 1024, 1024) tensor in [-1, 1].
    image = to_tensor(input_image)
    image = image.unsqueeze_(0).float() * 2 - 1  # [0, 1] --> [-1, 1]
    if image.shape[1] != 3:
        image = image.expand(-1, 3, -1, -1)
    image = F.interpolate(image, (1024, 1024))
    image = image.to(dtype).to(device)  # `dtype` is the module-level fp16 dtype set below
    return image
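
# Usage sketch (hypothetical file name, for illustration only):
#   img = Image.open("example.jpg").convert("RGB")
#   x = preprocess_image(img, device)  # -> shape (1, 3, 1024, 1024), values in [-1, 1]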

def load_description(fp):
    with open(fp, 'r', encoding='utf-8') as f:
        content = f.read()
    return content

def preprocess_mask(input_mask, device):
    # The drawn mask layer is RGBA: painted pixels are opaque, the rest transparent.
    # Build a grayscale mask where transparent areas become black (0) and painted
    # (non-transparent) areas become white (255).
    alpha = input_mask.split()[-1]
    new_mask = alpha.point(lambda a: 0 if a == 0 else 255)
    mask = to_tensor(new_mask)
    mask = mask.unsqueeze_(0).float()  # 0 or 1
    mask = F.interpolate(mask, (1024, 1024))
    # Feather the edges, then re-binarize; this slightly dilates the mask.
    mask = gaussian_blur(mask, kernel_size=(77, 77))
    mask[mask < 0.1] = 0
    mask[mask >= 0.1] = 1
    mask = mask.to(dtype).to(device)
    return mask
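
# Usage sketch (hypothetical; the RGBA layer comes from gr.ImageMask below):
#   m = preprocess_mask(rgba_layer, device)  # -> binary (0/1) mask, shape (1, 1, 1024, 1024)
# The blur + re-threshold above slightly dilates the drawn region so the erased
# area fully covers the object's edges.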

def make_redder(img, mask, increase_factor=0.4):
    # Boost the red channel inside the masked region (currently unused by the demo).
    img_redder = img.clone()
    mask_expanded = mask.expand_as(img)
    img_redder[0][mask_expanded[0] == 1] = torch.clamp(
        img_redder[0][mask_expanded[0] == 1] + increase_factor, 0, 1
    )
    return img_redder
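
# Usage sketch (hypothetical, for visually checking mask placement):
#   vis = make_redder(source_image.squeeze(0).add(1).div(2).cpu().float(),
#                     mask.squeeze(0).cpu().float())
#   to_pil_image(vis).save("mask_overlay.png")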

# Model loading parameters
is_cpu_offload_enabled = False
is_attention_slicing_enabled = True

# Load model
dtype = torch.float16
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
scheduler = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
)
model_path = "stabilityai/stable-diffusion-xl-base-1.0"
pipeline = DiffusionPipeline.from_pretrained(
    model_path,
    custom_pipeline="pipeline_stable_diffusion_xl_attentive_eraser.py",
    scheduler=scheduler,
    variant="fp16",
    use_safetensors=True,
    torch_dtype=dtype,
).to(device)
if is_attention_slicing_enabled:
    pipeline.enable_attention_slicing()
if is_cpu_offload_enabled:
    pipeline.enable_model_cpu_offload()
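
# Note: passing a ".py" path as `custom_pipeline` makes diffusers load the pipeline
# class from that local file, so pipeline_stable_diffusion_xl_attentive_eraser.py
# must ship alongside this script in the Space repo.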

@spaces.GPU  # allocate ZeroGPU for this call (this Space runs on Zero)
def remove(gradio_image, rm_guidance_scale=9, num_inference_steps=50, seed=42, strength=0.8, similarity_suppression_steps=9, similarity_suppression_scale=0.3):
    try:
        generator = torch.Generator(device).manual_seed(seed)
        prompt = ""  # Set prompt to null
        source_image_pure = gradio_image["background"]
        mask_image_pure = gradio_image["layers"][0]
        source_image = preprocess_image(source_image_pure.convert('RGB'), device)
        mask = preprocess_mask(mask_image_pure, device)
        START_STEP = 0  # AAS start step
        END_STEP = int(strength * num_inference_steps)  # AAS end step
        LAYER = 34  # 0~23 down, 24~33 mid, 34~69 up / AAS start layer
        END_LAYER = 70  # AAS end layer
        ss_steps = similarity_suppression_steps  # similarity suppression steps
        ss_scale = similarity_suppression_scale  # similarity suppression scale
        image = pipeline(
            prompt=prompt,
            image=source_image,
            mask_image=mask,
            height=1024,
            width=1024,
            AAS=True,  # enable AAS
            strength=strength,  # inpainting strength
            rm_guidance_scale=rm_guidance_scale,  # removal guidance scale
            ss_steps=ss_steps,  # similarity suppression steps
            ss_scale=ss_scale,  # similarity suppression scale
            AAS_start_step=START_STEP,  # AAS start step
            AAS_start_layer=LAYER,  # AAS start layer
            AAS_end_layer=END_LAYER,  # AAS end layer
            num_inference_steps=num_inference_steps,  # AAS_end_step = int(strength * num_inference_steps)
            generator=generator,
            guidance_scale=1,
        ).images[0]
        print('Inference: DONE.')
        pil_mask = to_pil_image(mask.squeeze(0))
        # Feathered mask variants (computed but currently unused in the output)
        pil_mask_blurred = pil_mask.filter(ImageFilter.GaussianBlur(radius=15))
        mask_blurred = to_tensor(pil_mask_blurred).unsqueeze_(0).to(mask.device)
        mask_f = 1 - (1 - mask) * (1 - mask_blurred)
        return source_image_pure, pil_mask, image
    except Exception:
        print(traceback.format_exc())
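
# Standalone call sketch (hypothetical; mirrors the dict gr.ImageMask passes in):
#   editor_value = {"background": Image.open("scene.png"),
#                   "layers": [Image.open("mask_rgba.png")]}  # RGBA, painted area opaque
#   src, mask_vis, result = remove(editor_value, seed=123)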

title = """<h1 align="center">Object Remove</h1>"""
with gr.Blocks() as demo:
    gr.HTML(load_description("assets/title.md"))
    with gr.Row():
        with gr.Column():
            with gr.Accordion("Advanced Options", open=False):
                guidance_scale = gr.Slider(
                    minimum=1,
                    maximum=20,
                    value=9,
                    step=0.1,
                    label="Guidance Scale"
                )
                num_steps = gr.Slider(
                    minimum=5,
                    maximum=100,
                    value=50,
                    step=1,
                    label="Steps"
                )
                seed = gr.Slider(
                    minimum=42,
                    maximum=999999,
                    value=42,
                    step=1,
                    label="Seed"
                )
                strength = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0.8,
                    step=0.1,
                    label="Strength"
                )
                similarity_suppression_steps = gr.Slider(
                    minimum=0,
                    maximum=10,
                    value=9,
                    step=1,
                    label="Similarity Suppression Steps"
                )
                similarity_suppression_scale = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0.3,
                    step=0.1,
                    label="Similarity Suppression Scale"
                )
            input_image = gr.ImageMask(
                type="pil", label="Input Image", crop_size=(1200, 1200), layers=False
            )
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    run_button = gr.Button("Generate")
                    result = gr.Gallery(
                        label="Generated images",
                        show_label=False,
                        elem_id="gallery",
                        columns=[3],
                        rows=[1],
                        object_fit="contain",
                        height="auto",
                    )
    run_button.click(
        fn=remove,
        inputs=[input_image, guidance_scale, num_steps, seed, strength, similarity_suppression_steps, similarity_suppression_scale],
        outputs=result,
    )

demo.queue(max_size=12).launch(share=False)