import datetime
import math
import os
from enum import Enum
from pathlib import Path
from typing import Optional

import cv2
import gradio as gr
import numpy as np
import safetensors.torch as sf
import supervision as sv
import torch
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler
from diffusers.models.attention_processor import AttnProcessor2_0
from PIL import Image
from torch.hub import download_url_to_file
from transformers import CLIPTextModel, CLIPTokenizer

import db_examples
from briarmbg import BriaRMBG
from Depth.depth_anything_v2.dpt import DepthAnythingV2
from utils.florence import load_florence_model, run_florence_inference, \
    FLORENCE_OPEN_VOCABULARY_DETECTION_TASK
from utils.sam import load_sam_image_model, run_sam_inference

# from FLORENCE import spaces

DEVICE = torch.device("cuda")

# NOTE: this enters a bfloat16 autocast context for the rest of the import and
# never exits it (kept from the original script).
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)
SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)


# NOTE: the original also decorated this function with `@spaces.GPU(duration=20)`,
# but `spaces` is never imported (its import above is commented out), so the
# module raised NameError on import. The decorator is therefore dropped.
@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def process_image(image_input, text_input) -> Optional[Image.Image]:
    """Detect the object described by ``text_input`` and return its mask.

    Florence-2 open-vocabulary detection proposes boxes for the text prompt,
    SAM turns them into masks, and the first mask is returned as a black/white
    PIL image.

    Args:
        image_input: The picture to search, either a PIL image or a numpy
            array (the Gradio input component is configured with
            ``type="numpy"``).
        text_input: Free-text description of the target object.

    Returns:
        A PIL image of the first detected mask (0/255), or ``None`` when an
        input is missing or nothing was detected.
    """
    if image_input is None:
        gr.Info("Please upload an image.")
        return None
    if not text_input:
        gr.Info("Please enter a text prompt.")
        return None

    # The code below relies on PIL semantics (`.size` is (width, height));
    # on a numpy array `.size` would be the element count.
    if isinstance(image_input, np.ndarray):
        image_input = Image.fromarray(image_input).convert("RGB")

    _, result = run_florence_inference(
        model=FLORENCE_MODEL,
        processor=FLORENCE_PROCESSOR,
        device=DEVICE,
        image=image_input,
        task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
        text=text_input
    )
    detections = sv.Detections.from_lmm(
        lmm=sv.LMM.FLORENCE_2,
        result=result,
        resolution_wh=image_input.size
    )
    detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
    if len(detections) == 0:
        gr.Info("No objects detected.")
        return None
    return Image.fromarray(detections.mask[0].astype("uint8") * 255)


try:
    import xformers
    import xformers.ops
    XFORMERS_AVAILABLE = True
    print("xformers is available - Using memory efficient attention")
except ImportError:
    XFORMERS_AVAILABLE = False
    print("xformers not available - Using default attention")

# 'stablediffusionapi/realistic-vision-v51'
# 'runwayml/stable-diffusion-v1-5'
sd15_name = 'stablediffusionapi/realistic-vision-v51'
tokenizer = CLIPTokenizer.from_pretrained(sd15_name, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(sd15_name, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(sd15_name, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(sd15_name, subfolder="unet")
rmbg = BriaRMBG.from_pretrained("briaai/RMBG-1.4")

# Change UNet: widen the input convolution from 4 to 12 channels. The first
# 4 channels keep the pretrained weights; the 8 new ones start at zero.
with torch.no_grad():
    new_conv_in = torch.nn.Conv2d(12, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
    new_conv_in.weight.zero_()
    new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
    new_conv_in.bias = unet.conv_in.bias
    unet.conv_in = new_conv_in

unet_original_forward = unet.forward


def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
    """Forward wrapper that appends the conditioning latents to ``sample``.

    The conditioning tensor is smuggled in through
    ``cross_attention_kwargs['concat_conds']``, tiled to the batch size of
    ``sample``, and concatenated on the channel axis. ``cross_attention_kwargs``
    is then emptied before calling the original forward.
    """
    c_concat = kwargs['cross_attention_kwargs']['concat_conds'].to(sample)
    # Repeat the conditions so they match the (possibly CFG-doubled) batch.
    c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0)
    new_sample = torch.cat([sample, c_concat], dim=1)
    kwargs['cross_attention_kwargs'] = {}
    return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs)


unet.forward = hooked_unet_forward

# Load
model_path = './models/iclight_sd15_fbc.safetensors'

if not os.path.exists(model_path):
    # The download fails if the target directory is missing, so create it.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    download_url_to_file(url='https://huggingface.co/lllyasviel/ic-light/resolve/main/iclight_sd15_fbc.safetensors', dst=model_path)

# Device and dtype setup
device = torch.device('cuda')
dtype = torch.float16
# float16 is used for the diffusion models; the original author notes it works
# well on an RTX 2070.

# Memory / speed settings.
torch.backends.cudnn.benchmark = True
if torch.cuda.is_available():
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    # NOTE: the original also assigned `torch.backends.cuda.max_split_size_mb = 512`.
    # That only creates an unused attribute; the allocator option is set through
    # the PYTORCH_CUDA_ALLOC_CONF environment variable, so the line is omitted.

# Move models to device with consistent dtype
text_encoder = text_encoder.to(device=device, dtype=dtype)
vae = vae.to(device=device, dtype=dtype)
unet = unet.to(device=device, dtype=dtype)
rmbg = rmbg.to(device=device, dtype=torch.float32)  # Keep this as float32

# Depth estimation model (small ViT encoder).
model = DepthAnythingV2(encoder='vits', features=64, out_channels=[48, 96, 192, 384])
model.load_state_dict(torch.load('checkpoints/depth_anything_v2_vits.pth', map_location=device))
# The module itself must live on the GPU too; `map_location` only affects the
# loaded tensors before they are copied into the (CPU) module.
# NOTE(review): assumes `infer_image` moves its input to CUDA, as in the
# upstream Depth-Anything-V2 code - confirm against the vendored copy.
model = model.to(device).eval()

# The Florence/SAM imports, global autocast, TF32 check and SAM model load that
# were repeated here already run once at the top of the file.

# Merge the IC-Light offset weights into the base UNet weights.
sd_offset = sf.load_file(model_path)
sd_origin = unet.state_dict()
sd_merged = {k: sd_origin[k] + sd_offset[k].to(device=device, dtype=dtype) for k in sd_origin.keys()}
unet.load_state_dict(sd_merged, strict=True)
del sd_offset, sd_origin, sd_merged


def enable_efficient_attention():
    """Select the attention implementation for the UNet and VAE.

    xformers memory-efficient attention is used when it is importable and can
    be enabled; otherwise both models fall back to ``AttnProcessor2_0``.
    """
    if XFORMERS_AVAILABLE:
        try:
            unet.set_use_memory_efficient_attention_xformers(True)
            vae.set_use_memory_efficient_attention_xformers(True)
        except Exception as e:  # best effort: any xformers failure falls back
            print(f"Xformers error: {e}")
        else:
            print("Enabled xformers memory efficient attention")
            return

    # The original called `set_attention_slice_size(4)` here and then installed
    # AttnProcessor2_0 anyway, which replaces any sliced processor, so only the
    # processor assignment is kept.
    print("Using PyTorch scaled dot product attention (AttnProcessor2_0)")
    unet.set_attn_processor(AttnProcessor2_0())
    vae.set_attn_processor(AttnProcessor2_0())


def clear_memory():
    """Release cached CUDA memory and wait for pending GPU work."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()


# Enable efficient attention
enable_efficient_attention()

# Samplers
ddim_scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
    steps_offset=1,
)

euler_a_scheduler = EulerAncestralDiscreteScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    steps_offset=1
)

dpmpp_2m_sde_karras_scheduler = DPMSolverMultistepScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    algorithm_type="sde-dpmsolver++",
    use_karras_sigmas=True,
    steps_offset=1
)

# Pipelines
t2i_pipe = StableDiffusionPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=dpmpp_2m_sde_karras_scheduler,
    safety_checker=None,
    requires_safety_checker=False,
    feature_extractor=None,
    image_encoder=None
)

i2i_pipe = StableDiffusionImg2ImgPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=dpmpp_2m_sde_karras_scheduler,
    safety_checker=None,
    requires_safety_checker=False,
    feature_extractor=None,
    image_encoder=None
)


@torch.inference_mode()
def encode_prompt_inner(txt: str):
    """Encode a prompt of any length into per-chunk CLIP hidden states.

    The prompt is tokenized without truncation, split into chunks of
    ``model_max_length - 2`` tokens, and each chunk is wrapped in BOS/EOS and
    padded with EOS to ``model_max_length``. An empty prompt still yields one
    chunk (BOS followed by EOS padding).

    Returns:
        The text encoder's ``last_hidden_state`` with one row per chunk.
    """
    max_length = tokenizer.model_max_length
    chunk_length = tokenizer.model_max_length - 2
    id_start = tokenizer.bos_token_id
    id_end = tokenizer.eos_token_id
    id_pad = id_end

    def pad(x, p, i):
        return x[:i] if len(x) >= i else x + [p] * (i - len(x))

    tokens = tokenizer(txt, truncation=False, add_special_tokens=False)["input_ids"]
    # `max(..., 1)` guarantees one chunk for an empty prompt; zero chunks would
    # produce an empty tensor here and a division by zero in encode_prompt_pair.
    chunks = [[id_start] + tokens[i: i + chunk_length] + [id_end]
              for i in range(0, max(len(tokens), 1), chunk_length)]
    chunks = [pad(ck, id_pad, max_length) for ck in chunks]

    token_ids = torch.tensor(chunks).to(device=device, dtype=torch.int64)
    conds = text_encoder(token_ids).last_hidden_state

    return conds


@torch.inference_mode()
def encode_prompt_pair(positive_prompt, negative_prompt):
    """Encode positive/negative prompts to embeddings of equal sequence length.

    The prompt with fewer chunks is repeated until both have the same number
    of chunks; the chunks are then concatenated along the sequence axis.

    Returns:
        ``(cond, uncond)``, each with a leading batch dimension of 1.
    """
    c = encode_prompt_inner(positive_prompt)
    uc = encode_prompt_inner(negative_prompt)

    max_chunk = max(len(c), len(uc))
    c_repeat = int(math.ceil(max_chunk / float(len(c))))
    uc_repeat = int(math.ceil(max_chunk / float(len(uc))))

    c = torch.cat([c] * c_repeat, dim=0)[:max_chunk]
    uc = torch.cat([uc] * uc_repeat, dim=0)[:max_chunk]

    c = torch.cat([p[None, ...] for p in c], dim=1)
    uc = torch.cat([p[None, ...] for p in uc], dim=1)

    return c, uc


@torch.inference_mode()
def pytorch2numpy(imgs, quant=True):
    """Convert channel-first tensors in [-1, 1] to channel-last numpy images.

    With ``quant=True`` the result is uint8 in [0, 255]; otherwise float32 in
    [0, 1].
    """
    results = []
    for x in imgs:
        y = x.movedim(0, -1)

        if quant:
            y = y * 127.5 + 127.5
            y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
        else:
            y = y * 0.5 + 0.5
            y = y.detach().float().cpu().numpy().clip(0, 1).astype(np.float32)

        results.append(y)
    return results


@torch.inference_mode()
def numpy2pytorch(imgs):
    """Stack channel-last images into one channel-first float tensor."""
    h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.0 - 1.0  # so that 127 must be strictly 0.0
    h = h.movedim(-1, 1)
    return h


def resize_and_center_crop(image, target_width, target_height):
    """Scale ``image`` to cover the target size, then crop it around the center."""
    pil_image = Image.fromarray(image)
    original_width, original_height = pil_image.size
    scale_factor = max(target_width / original_width, target_height / original_height)
    resized_width = int(round(original_width * scale_factor))
    resized_height = int(round(original_height * scale_factor))
    resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS)
    left = (resized_width - target_width) / 2
    top = (resized_height - target_height) / 2
    right = (resized_width + target_width) / 2
    bottom = (resized_height + target_height) / 2
    cropped_image = resized_image.crop((left, top, right, bottom))
    return np.array(cropped_image)


def resize_without_crop(image, target_width, target_height):
    """Resize ``image`` to exactly the target size, ignoring aspect ratio."""
    pil_image = Image.fromarray(image)
    resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
    return np.array(resized_image)


@torch.inference_mode()
def run_rmbg(img, sigma=0.0):
    """Remove the background of ``img`` with the RMBG matting model.

    Args:
        img: HxWx3 or HxWx4 uint8 array. RGBA input is first composited onto
            a white background.
        sigma: Offset added to the foreground before it is blended toward
            grey (127) by the predicted alpha.

    Returns:
        ``(result, rgba)``: the HxWx3 uint8 foreground on a grey background,
        and an HxWx4 uint8 image whose 4th channel is the predicted alpha
        scaled to 0..255.

    Raises:
        ValueError: if the image does not end up with exactly 3 channels.
    """
    # Convert RGBA to RGB if needed
    if img.shape[-1] == 4:
        # Use white background for alpha composition
        alpha = img[..., 3:] / 255.0
        rgb = img[..., :3]
        white_bg = np.ones_like(rgb) * 255
        img = (rgb * alpha + white_bg * (1 - alpha)).astype(np.uint8)

    H, W, C = img.shape
    if C != 3:
        raise ValueError(f"run_rmbg expects a 3-channel image, got {C} channels")

    # Feed the model roughly 256 * 64 * 64 pixels, sides multiples of 64
    # (at least 64 so extreme aspect ratios never produce a zero side).
    k = (256.0 / float(H * W)) ** 0.5
    feed = resize_without_crop(img, max(64, int(64 * round(W * k))), max(64, int(64 * round(H * k))))
    feed = numpy2pytorch([feed]).to(device=device, dtype=torch.float32)
    alpha = rmbg(feed)[0][0]
    alpha = torch.nn.functional.interpolate(alpha, size=(H, W), mode="bilinear")
    alpha = alpha.movedim(1, -1)[0]
    alpha = alpha.detach().float().cpu().numpy().clip(0, 1)

    # Create RGBA image
    rgba = np.dstack((img, alpha * 255)).astype(np.uint8)
    result = 127 + (img.astype(np.float32) - 127 + sigma) * alpha

    return result.clip(0, 255).astype(np.uint8), rgba


def _resolve_background(bg_source, input_bg, image_width, image_height):
    """Return the background image implied by ``bg_source`` (a BGSource)."""
    if bg_source in (BGSource.UPLOAD, BGSource.UPLOAD_FLIP):
        if input_bg is None:
            raise ValueError('A background image is required for this background source.')
        return np.fliplr(input_bg) if bg_source == BGSource.UPLOAD_FLIP else input_bg

    if bg_source == BGSource.GREY:
        return np.zeros(shape=(image_height, image_width, 3), dtype=np.uint8) + 64

    # The remaining sources are linear brightness ramps (224 = lit side).
    if bg_source == BGSource.LEFT:
        image = np.tile(np.linspace(224, 32, image_width), (image_height, 1))
    elif bg_source == BGSource.RIGHT:
        image = np.tile(np.linspace(32, 224, image_width), (image_height, 1))
    elif bg_source == BGSource.TOP:
        image = np.tile(np.linspace(224, 32, image_height)[:, None], (1, image_width))
    elif bg_source == BGSource.BOTTOM:
        image = np.tile(np.linspace(32, 224, image_height)[:, None], (1, image_width))
    else:
        # The original raised a bare string, which is itself a TypeError.
        raise ValueError('Wrong background source!')
    return np.stack((image,) * 3, axis=-1).astype(np.uint8)


@torch.inference_mode()
def process(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg,
            highres_scale, highres_denoise, bg_source):
    """Run the two-pass (base + high-res) relighting diffusion.

    The foreground and background are VAE-encoded and passed to the UNet as
    concatenated conditions. A text-to-image pass generates latents at the
    requested size; the decoded result is upscaled by ``highres_scale``
    (rounded to multiples of 64) and refined by an image-to-image pass.

    Args:
        input_fg: Foreground image, HxWx3 uint8.
        input_bg: Background image; only used for the upload sources.
        bg_source: A ``BGSource`` value string.
        Remaining arguments mirror the UI controls.

    Returns:
        ``(pixels, [fg, bg])``: float32 images in [0, 1], one per sample, and
        the conditioning images used for the high-res pass.

    Raises:
        ValueError: for an unknown ``bg_source`` or a missing background.
    """
    clear_memory()

    # UI widgets may deliver floats; these values are used as sizes/counts.
    image_width, image_height = int(image_width), int(image_height)
    num_samples, steps = int(num_samples), int(steps)

    input_bg = _resolve_background(BGSource(bg_source), input_bg, image_width, image_height)

    rng = torch.Generator(device=device).manual_seed(int(seed))

    fg = resize_and_center_crop(input_fg, image_width, image_height)
    bg = resize_and_center_crop(input_bg, image_width, image_height)
    concat_conds = numpy2pytorch([fg, bg]).to(device=vae.device, dtype=vae.dtype)
    concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
    # Fold the two encoded images into the channel axis of a single sample.
    concat_conds = torch.cat([c[None, ...] for c in concat_conds], dim=1)

    conds, unconds = encode_prompt_pair(positive_prompt=prompt + ', ' + a_prompt, negative_prompt=n_prompt)

    latents = t2i_pipe(
        prompt_embeds=conds,
        negative_prompt_embeds=unconds,
        width=image_width,
        height=image_height,
        num_inference_steps=steps,
        num_images_per_prompt=num_samples,
        generator=rng,
        output_type='latent',
        guidance_scale=cfg,
        cross_attention_kwargs={'concat_conds': concat_conds},
    ).images.to(vae.dtype) / vae.config.scaling_factor

    pixels = vae.decode(latents).sample
    pixels = pytorch2numpy(pixels)
    pixels = [resize_without_crop(
        image=p,
        target_width=int(round(image_width * highres_scale / 64.0) * 64),
        target_height=int(round(image_height * highres_scale / 64.0) * 64))
        for p in pixels]

    pixels = numpy2pytorch(pixels).to(device=vae.device, dtype=vae.dtype)
    latents = vae.encode(pixels).latent_dist.mode() * vae.config.scaling_factor
    latents = latents.to(device=unet.device, dtype=unet.dtype)

    # The high-res pass works at the size implied by the new latents.
    image_height, image_width = latents.shape[2] * 8, latents.shape[3] * 8
    fg = resize_and_center_crop(input_fg, image_width, image_height)
    bg = resize_and_center_crop(input_bg, image_width, image_height)
    concat_conds = numpy2pytorch([fg, bg]).to(device=vae.device, dtype=vae.dtype)
    concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
    concat_conds = torch.cat([c[None, ...] for c in concat_conds], dim=1)

    latents = i2i_pipe(
        image=latents,
        strength=highres_denoise,
        prompt_embeds=conds,
        negative_prompt_embeds=unconds,
        width=image_width,
        height=image_height,
        num_inference_steps=int(round(steps / highres_denoise)),
        num_images_per_prompt=num_samples,
        generator=rng,
        output_type='latent',
        guidance_scale=cfg,
        cross_attention_kwargs={'concat_conds': concat_conds},
    ).images.to(vae.dtype) / vae.config.scaling_factor

    pixels = vae.decode(latents).sample
    pixels = pytorch2numpy(pixels, quant=False)

    clear_memory()
    return pixels, [fg, bg]


def save_images(images, prefix="relight"):
    """Save ``images`` as PNG files under ``outputs/``.

    File names are ``{prefix}_{timestamp}_{index}.png`` with a 1-based index.

    Args:
        images: Iterable of numpy arrays or PIL images.
        prefix: File name prefix.

    Returns:
        The list of paths written, in input order.
    """
    output_dir = Path("outputs")
    output_dir.mkdir(parents=True, exist_ok=True)

    # One timestamp per call keeps a batch's files grouped together.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

    saved_paths = []
    for i, img in enumerate(images):
        if isinstance(img, np.ndarray):
            img = Image.fromarray(img)

        filepath = output_dir / f"{prefix}_{timestamp}_{i+1}.png"
        img.save(filepath)
        saved_paths.append(filepath)  # the original never recorded the path

    return saved_paths


@torch.inference_mode()
def process_relight(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt,
                    cfg, highres_scale, highres_denoise, bg_source):
    """Matte the foreground, relight it, save the results and return them.

    Returns:
        A list of HxWx3 uint8 images, one per sample.
    """
    input_fg, matting = run_rmbg(input_fg)
    results, extra_images = process(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps,
                                    a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source)
    results = [(x * 255.0).clip(0, 255).astype(np.uint8) for x in results]

    # Save the generated images
    save_images(results, prefix="relight")

    return results


@torch.inference_mode()
def process_normal(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt,
                   cfg, highres_scale, highres_denoise, bg_source):
    """Estimate a normal map by relighting from four directions.

    The foreground is relit with left/right/bottom/top light ramps (one sample
    each; ``num_samples`` and ``bg_source`` are ignored). The ratio of each
    result to their average gives horizontal and vertical components of the
    normal, which is blended with a flat normal outside the matte.

    Returns:
        A list of nine uint8 images: the normal map, the four ratio images and
        the four relit images.
    """
    input_fg, matting = run_rmbg(input_fg, sigma=16)

    def relight_from(source):
        return process(input_fg, input_bg, prompt, image_width, image_height, 1, seed, steps, a_prompt, n_prompt,
                       cfg, highres_scale, highres_denoise, source.value)[0][0]

    print('left ...')
    left = relight_from(BGSource.LEFT)

    print('right ...')
    right = relight_from(BGSource.RIGHT)

    print('bottom ...')
    bottom = relight_from(BGSource.BOTTOM)

    print('top ...')
    top = relight_from(BGSource.TOP)

    inner_results = [left * 2.0 - 1.0, right * 2.0 - 1.0, bottom * 2.0 - 1.0, top * 2.0 - 1.0]

    ambient = (left + right + bottom + top) / 4.0
    h, w, _ = ambient.shape

    # `matting` is the uint8 RGBA image from run_rmbg; its alpha is channel 3.
    # (The original read channel 0 and multiplied by 255, which saturated.)
    matting = resize_and_center_crop(matting[..., 3], w, h).astype(np.float32)[..., None] / 255.0

    def safe_divide(a, b):
        e = 1e-5
        return ((a + e) / (b + e)) - 1.0

    left = safe_divide(left, ambient)
    right = safe_divide(right, ambient)
    bottom = safe_divide(bottom, ambient)
    top = safe_divide(top, ambient)

    u = (right - left) * 0.5
    v = (top - bottom) * 0.5

    sigma = 10.0
    u = np.mean(u, axis=2)
    v = np.mean(v, axis=2)
    h = (1.0 - u ** 2.0 - v ** 2.0).clip(0, 1e5) ** (0.5 * sigma)
    z = np.zeros_like(h)

    normal = np.stack([u, v, h], axis=2)
    normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5
    # Outside the matte use the flat normal (0, 0, 1).
    normal = normal * matting + np.stack([z, z, 1 - z], axis=2) * (1 - matting)

    results = [normal, left, right, bottom, top] + inner_results
    results = [(x * 127.5 + 127.5).clip(0, 255).astype(np.uint8) for x in results]

    # Save the generated images
    save_images(results, prefix="normal")

    return results


quick_prompts = [
    'modern sofa in living room',
    'elegant dining table with chairs',
    'luxurious bed in bedroom, cinematic lighting',
    'minimalist office desk, natural lighting',
    'vintage wooden cabinet, warm lighting',
    'contemporary bookshelf, ambient lighting',
    'designer armchair, dramatic lighting',
    'modern kitchen island, bright lighting',
]
quick_prompts = [[x] for x in quick_prompts]


class BGSource(Enum):
    """Background choices; the values are the labels shown in the UI."""
    UPLOAD = "Use Background Image"
    UPLOAD_FLIP = "Use Flipped Background Image"
    LEFT = "Left Light"
    RIGHT = "Right Light"
    TOP = "Top Light"
    BOTTOM = "Bottom Light"
    GREY = "Ambient"


class MaskMover:
    """Holds an extracted foreground and pastes it onto backgrounds."""

    def __init__(self):
        self.extracted_fg = None
        self.original_fg = None  # Store original foreground

    def set_extracted_fg(self, fg_image):
        """Store the extracted foreground with alpha channel"""
        self.extracted_fg = fg_image.copy()
        self.original_fg = fg_image.copy()  # Keep original
        return fg_image

    def create_composite(self, background, x_pos, y_pos, scale=1.0):
        """Paste the stored foreground onto ``background``.

        Args:
            background: numpy array or PIL image; returned unchanged if it or
                the stored foreground is ``None``.
            x_pos, y_pos: Pixel position of the foreground's center.
            scale: Resize factor applied to the foreground.

        Returns:
            The composite as a numpy array.
        """
        if self.original_fg is None or background is None:
            return background

        # Convert inputs to PIL Images
        if isinstance(background, np.ndarray):
            bg = Image.fromarray(background)
        else:
            bg = background

        if isinstance(self.original_fg, np.ndarray):
            fg = Image.fromarray(self.original_fg)
        else:
            fg = self.original_fg

        # Scale the foreground; keep at least 1px so resize never gets a
        # zero-sized target at very small scales.
        new_width = max(1, int(fg.width * scale))
        new_height = max(1, int(fg.height * scale))
        fg = fg.resize((new_width, new_height), Image.LANCZOS)

        # Center the scaled foreground at the position
        x = int(x_pos - new_width / 2)
        y = int(y_pos - new_height / 2)

        # Create composite
        result = bg.copy()
        if fg.mode == 'RGBA':
            result.paste(fg, (x, y), fg.split()[3])  # Use alpha channel as mask
        else:
            result.paste(fg, (x, y))

        return np.array(result)


def get_depth(image):
    """Return a colorized depth map of ``image`` as a PIL image.

    ``image`` is treated as RGB; ``None`` yields ``None``.
    """
    if image is None:
        return None
    # Convert from PIL/gradio format to cv2
    raw_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    # Get depth map
    depth = model.infer_image(raw_img)  # HxW raw depth map

    # Normalize depth for visualization; a constant map would divide by zero.
    depth_range = depth.max() - depth.min()
    if depth_range > 0:
        depth = ((depth - depth.min()) / depth_range * 255).astype(np.uint8)
    else:
        depth = np.zeros_like(depth, dtype=np.uint8)

    # Convert to RGB for display
    depth_colored = cv2.applyColorMap(depth, cv2.COLORMAP_INFERNO)
    depth_colored = cv2.cvtColor(depth_colored, cv2.COLOR_BGR2RGB)

    return Image.fromarray(depth_colored)


block = gr.Blocks().queue()
with block:
    mask_mover = MaskMover()

    with gr.Row():
        gr.Markdown("## IC-Light (Relighting with Foreground and Background Condition)")
        gr.Markdown("💾 Generated images are automatically saved to 'outputs' folder")

    with gr.Row():
        with gr.Column():
            # Step 1: Input and Extract
            with gr.Group():
                gr.Markdown("### Step 1: Extract Foreground")
                input_image = gr.Image(type="numpy", label="Input Image", height=480)
                input_text = gr.Textbox(label="Describe target object")
                find_objects_button = gr.Button(value="Find Objects")
                extract_button = gr.Button(value="Remove Background")
                extracted_fg = gr.Image(type="numpy", label="Extracted Foreground", height=480)

            # Step 2: Background and Position
            with gr.Group():
                gr.Markdown("### Step 2: Position on Background")
                input_bg = gr.Image(type="numpy", label="Background Image", height=480)

                with gr.Row():
                    x_slider = gr.Slider(
                        minimum=0,
                        maximum=1000,
                        label="X Position",
                        value=500,
                        visible=False
                    )
                    y_slider = gr.Slider(
                        minimum=0,
                        maximum=1000,
                        label="Y Position",
                        value=500,
                        visible=False
                    )
                    fg_scale_slider = gr.Slider(
                        label="Foreground Scale",
                        minimum=0.01,
                        maximum=3.0,
                        value=1.0,
                        step=0.01
                    )

                get_depth_button = gr.Button(value="Get Depth")
                depth_image = gr.Image(type="numpy", label="Depth Image", height=480)

                editor = gr.ImageEditor(
                    type="numpy",
                    label="Position Foreground",
                    height=480,
                    visible=False
                )

            # Step 3: Relighting Options
            with gr.Group():
                gr.Markdown("### Step 3: Relighting Settings")
                prompt = gr.Textbox(label="Prompt")
                bg_source = gr.Radio(
                    choices=[e.value for e in BGSource],
                    value=BGSource.UPLOAD.value,
                    label="Background Source",
                    type='value'
                )

                example_prompts = gr.Dataset(
                    samples=quick_prompts,
                    label='Prompt Quick List',
                    components=[prompt]
                )
                # bg_gallery = gr.Gallery(
                #     height=450,
                #     label='Background Quick List',
                #     value=db_examples.bg_samples,
                #     columns=5,
                #     allow_preview=False
                # )
                relight_button = gr.Button(value="Relight")

            # Additional settings
            with gr.Group():
                with gr.Row():
                    num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
                    seed = gr.Number(label="Seed", value=12345, precision=0)
                with gr.Row():
                    image_width = gr.Slider(label="Image Width", minimum=256, maximum=1024, value=512, step=64)
                    image_height = gr.Slider(label="Image Height", minimum=256, maximum=1024, value=640, step=64)

            with gr.Accordion("Advanced options", open=False):
                steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
                cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=7.0, step=0.01)
                highres_scale = gr.Slider(label="Highres Scale", minimum=1.0, maximum=2.0, value=1.2, step=0.01)
                highres_denoise = gr.Slider(label="Highres Denoise", minimum=0.1, maximum=0.9, value=0.5, step=0.01)
                a_prompt = gr.Textbox(label="Added Prompt", value='best quality')
                n_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value='lowres, bad anatomy, bad hands, cropped, worst quality'
                )
                normal_button = gr.Button(value="Compute Normal (4x Slower)")

        with gr.Column():
            result_gallery = gr.Image(height=832, label='Outputs')

    # Event handlers
    def extract_foreground(image):
        """Matte the input image, remember its RGBA cut-out, show the sliders."""
        if image is None:
            return None, gr.update(visible=True), gr.update(visible=True)
        result, rgba = run_rmbg(image)
        mask_mover.set_extracted_fg(rgba)

        return result, gr.update(visible=True), gr.update(visible=True)

    # Clean copy of the uploaded background. The sliders write the composite
    # back into `input_bg`, so the pristine image has to be kept separately.
    original_bg = None

    def remember_background(background):
        """Record a freshly uploaded background as the clean original."""
        global original_bg
        original_bg = None if background is None else background.copy()

    # Without this the first background ever composited would be reused
    # after the user uploads a different one.
    input_bg.upload(
        fn=remember_background,
        inputs=[input_bg],
        outputs=None
    )

    extract_button.click(
        fn=extract_foreground,
        inputs=[input_image],
        outputs=[extracted_fg, x_slider, y_slider]
    )

    find_objects_button.click(
        fn=process_image,
        inputs=[input_image, input_text],
        outputs=[extracted_fg]
    )

    get_depth_button.click(
        fn=get_depth,
        inputs=[input_bg],
        outputs=[depth_image]
    )

    def update_position(background, x_pos, y_pos, scale):
        """Update composite when position changes"""
        global original_bg
        if background is None:
            return None

        if original_bg is None:
            original_bg = background.copy()

        # Convert string values to float
        x_pos = float(x_pos)
        y_pos = float(y_pos)
        scale = float(scale)

        return mask_mover.create_composite(original_bg, x_pos, y_pos, scale)

    # Any of the three controls re-renders the composite.
    for position_control in (x_slider, y_slider, fg_scale_slider):
        position_control.change(
            fn=update_position,
            inputs=[input_bg, x_slider, y_slider, fg_scale_slider],
            outputs=[input_bg]
        )

    # The first entry is a placeholder for the foreground; the handlers
    # replace it with the cropped composite. The last three are position/scale.
    ips = [input_bg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg,
           highres_scale, highres_denoise, bg_source, x_slider, y_slider, fg_scale_slider]

    def _crop_around_foreground(background, x_pos, y_pos, scale):
        """Cut the padded region around the placed foreground out of ``background``.

        Returns:
            ``(crop_x, crop_y, crop_width, crop_height, crop_region,
            cropped_composite)``. Width and height are multiples of 8 and the
            region lies inside the background.
        """
        # Foreground size after scaling
        fg = Image.fromarray(mask_mover.original_fg)
        new_width = int(fg.width * scale)
        new_height = int(fg.height * scale)

        # Region around the foreground with 20% extra context on each side
        padding = 0.2
        crop_x = int(int(x_pos - new_width / 2) - new_width * padding)
        crop_y = int(int(y_pos - new_height / 2) - new_height * padding)
        crop_width = int(new_width * (1 + 2 * padding))
        crop_height = int(new_height * (1 + 2 * padding))

        # Round the dimensions up to multiples of 8
        crop_width = ((crop_width + 7) // 8) * 8
        crop_height = ((crop_height + 7) // 8) * 8

        # Keep the region inside the image
        bg_height, bg_width = background.shape[:2]
        crop_x = max(0, min(crop_x, bg_width - crop_width))
        crop_y = max(0, min(crop_y, bg_height - crop_height))
        crop_width = min(crop_width, bg_width - crop_x)
        crop_height = min(crop_height, bg_height - crop_y)

        # Clamping may have broken the multiple-of-8 property; round down.
        crop_width = (crop_width // 8) * 8
        crop_height = (crop_height // 8) * 8

        crop_region = background[crop_y:crop_y + crop_height, crop_x:crop_x + crop_width]

        # Foreground center in crop coordinates. Computed from the actual crop
        # origin so it stays exact after padding and boundary clamping.
        fg_local_x = int(x_pos - crop_x)
        fg_local_y = int(y_pos - crop_y)
        cropped_composite = mask_mover.create_composite(crop_region, fg_local_x, fg_local_y, scale)

        return crop_x, crop_y, crop_width, crop_height, crop_region, cropped_composite

    def _run_on_crop(args, run):
        """Shared body of the relight/normal handlers.

        ``args`` are the values of ``ips``; ``run`` receives the argument list
        for ``process_relight``/``process_normal`` and returns one HxWx3 uint8
        image, which is pasted back into a copy of the background.
        """
        if mask_mover.extracted_fg is None:
            gr.Warning("Please extract foreground first")
            return None

        background = args[1]
        if background is None:
            gr.Warning("Please upload a background image")
            return None

        x_pos = float(args[-3])  # x_slider value
        y_pos = float(args[-2])  # y_slider value
        scale = float(args[-1])  # fg_scale_slider value

        crop_x, crop_y, crop_width, crop_height, crop_region, cropped_composite = \
            _crop_around_foreground(background, x_pos, y_pos, scale)
        if crop_width <= 0 or crop_height <= 0:
            gr.Warning("Foreground region is too small to process")
            return background

        # Process the cropped region
        crop_args = list(args)
        crop_args[0] = cropped_composite
        crop_args[1] = crop_region
        crop_args[3] = crop_width
        crop_args[4] = crop_height
        crop_args = crop_args[:-3]  # Remove position and scale arguments

        processed_crop = run(crop_args)

        # The high-res pass can change the size; fit it back to the crop.
        if processed_crop.shape[:2] != (crop_height, crop_width):
            processed_crop = resize_without_crop(processed_crop, crop_width, crop_height)

        # Place the processed crop back into the original background
        result = background.copy()
        result[crop_y:crop_y + crop_height, crop_x:crop_x + crop_width] = processed_crop
        return result

    def process_relight_with_position(*args):
        """Relight the region around the placed foreground."""
        return _run_on_crop(args, lambda crop_args: process_relight(*crop_args)[0])

    relight_button.click(
        fn=process_relight_with_position,
        inputs=ips,
        outputs=[result_gallery]
    )

    def process_normal_with_position(*args):
        """Compute the normal map for the region around the placed foreground.

        ``process_normal`` returns nine images; the first (the normal map) is
        the one pasted back into the background.
        """
        return _run_on_crop(args, lambda crop_args: process_normal(*crop_args)[0])

    normal_button.click(
        fn=process_normal_with_position,
        inputs=ips,
        outputs=[result_gallery]
    )

    example_prompts.click(
        fn=lambda x: x[0],
        inputs=example_prompts,
        outputs=prompt,
        show_progress=False,
        queue=False
    )

    # def bg_gallery_selected(gal, evt: gr.SelectData):
    #     return gal[evt.index]['name']

    # bg_gallery.select(
    #     fn=bg_gallery_selected,
    #     inputs=bg_gallery,
    #     outputs=input_bg
    # )

block.launch(server_name='0.0.0.0')