"""Gradio demo that applies an uploaded logo to a user-described scene.

The app center-crops the uploaded logo to a square, duplicates it side by
side into a two-panel image, runs the FLUX.1-dev inpainting pipeline (with
the In-Context LoRA "visual-identity-design" weights) using the mask stored
in ``mask_square.png``, and shows both the full two-panel output and its
right half.
"""

import io

import gradio as gr
import torch
import spaces
from diffusers import FluxInpaintPipeline
from PIL import Image  # , ImageFile
import numpy as np

# Enable loading of truncated images
# ImageFile.LOAD_TRUNCATED_IMAGES = True

# Side length (pixels) of one square panel; the pipeline canvas is two
# panels wide, so crop size and pipeline dimensions must stay in sync.
PANEL_SIZE = 768

# Initialize the pipeline
pipe = FluxInpaintPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")
pipe.load_lora_weights(
    "ali-vilab/In-Context-LoRA",
    weight_name="visual-identity-design.safetensors",
)


def safe_open_image(image):
    """Return *image* as an RGB PIL image.

    Args:
        image: A PIL image, a numpy array (converted via
            ``Image.fromarray``), or raw encoded image bytes.

    Returns:
        The image as a ``PIL.Image.Image`` in ``'RGB'`` mode.

    Raises:
        ValueError: If the input cannot be converted (including
            unsupported input types).
    """
    try:
        if isinstance(image, np.ndarray):
            # Convert numpy array to PIL Image
            image = Image.fromarray(image)
        elif isinstance(image, bytes):
            # Handle bytes input
            image = Image.open(io.BytesIO(image))
        elif not isinstance(image, Image.Image):
            # Fail with a readable message instead of an AttributeError
            # on ``.mode`` further down.
            raise TypeError(
                f"Unsupported image type: {type(image).__name__}"
            )

        # Ensure the image is in RGB mode
        if image.mode != 'RGB':
            image = image.convert('RGB')
        return image
    except Exception as e:
        raise ValueError(f"Error processing input image: {str(e)}") from e


def square_center_crop(img, target_size=PANEL_SIZE):
    """Center-crop *img* to a square and resize it.

    Args:
        img: Anything accepted by :func:`safe_open_image`.
        target_size: Side length in pixels of the returned square image.

    Returns:
        An RGB PIL image of size ``(target_size, target_size)``.

    Raises:
        ValueError: If the image cannot be opened, is smaller than
            64 pixels on either side, or cannot be cropped/resized.
    """
    try:
        img = safe_open_image(img)

        # Ensure minimum size
        width, height = img.size
        if width < 64 or height < 64:
            raise ValueError(
                "Image is too small. Minimum size is 64x64 pixels."
            )

        # The crop square is the shorter side, centered on the longer one.
        # Since crop_size <= width and crop_size <= height, the box always
        # lies inside the image and needs no clamping.
        crop_size = min(width, height)
        left = (width - crop_size) // 2
        top = (height - crop_size) // 2
        img_cropped = img.crop(
            (left, top, left + crop_size, top + crop_size)
        )

        # Use high-quality resizing
        return img_cropped.resize(
            (target_size, target_size),
            Image.Resampling.LANCZOS,
            reducing_gap=3.0,
        )
    except Exception as e:
        raise ValueError(f"Error during image cropping: {str(e)}") from e


def duplicate_horizontally(img):
    """Return a new image with *img* pasted twice, side by side.

    Args:
        img: A square PIL image.

    Returns:
        An RGB PIL image of size ``(2 * width, height)``.

    Raises:
        ValueError: If *img* is not square or the paste fails.
    """
    try:
        width, height = img.size
        if width != height:
            raise ValueError(
                f"Input image must be square, got {width}x{height}"
            )

        # Ensure the source image is in RGB mode
        if img.mode != 'RGB':
            img = img.convert('RGB')

        # Create new image with RGB mode explicitly
        new_image = Image.new('RGB', (width * 2, height))
        new_image.paste(img, (0, 0))
        new_image.paste(img, (width, 0))
        return new_image
    except Exception as e:
        raise ValueError(f"Error during image duplication: {str(e)}") from e


def safe_crop_output(img):
    """Return the right half of *img*.

    Args:
        img: A PIL image (here: the two-panel pipeline output).

    Returns:
        The region from ``width // 2`` to the right edge, full height.

    Raises:
        ValueError: If the crop fails.
    """
    try:
        width, height = img.size
        half_width = width // 2
        return img.crop((half_width, 0, width, height))
    except Exception as e:
        raise ValueError(f"Error cropping output image: {str(e)}") from e


# Load the mask image with error handling
try:
    # ``Image.open`` is lazy and keeps the file handle open; the context
    # manager closes it, and ``convert`` forces the pixel data to be read
    # into an independent in-memory image first.
    with Image.open("mask_square.png") as mask_file:
        mask = mask_file.convert('RGB')
except Exception as e:
    raise RuntimeError(f"Error loading mask image: {str(e)}") from e


@spaces.GPU
def generate(image, prompt_user, progress=gr.Progress(track_tqdm=True)):
    """Generate the logo application and stream results to the UI.

    Args:
        image: The uploaded logo (the UI passes a PIL image).
        prompt_user: Text describing where the logo should be applied.
        progress: Gradio progress tracker (tracks tqdm progress bars).

    Yields:
        ``(None, out)`` first, where ``out`` is the full two-panel
        pipeline output, then ``(image_2, out)`` where ``image_2`` is the
        right half of ``out``.

    Raises:
        gr.Error: On missing inputs or any failure during processing.
    """
    try:
        if image is None:
            raise ValueError("No input image provided")

        if not prompt_user or prompt_user.strip() == "":
            raise ValueError("Please provide a prompt")

        prompt_structure = (
            "The two-panel image showcases the logo of a brand, "
            "[LEFT] the left panel is showing the logo "
            "[RIGHT] the right panel has this logo applied to "
        )
        prompt = prompt_structure + prompt_user

        # Process input image
        try:
            cropped_image = square_center_crop(image)
        except Exception as e:
            error_message = f"Error during cropping: {str(e)}"
            print(error_message)  # For logging
            raise gr.Error(error_message) from e

        print("Size after cropping", cropped_image.size)

        try:
            logo_dupli = duplicate_horizontally(cropped_image)
        except Exception as e:
            error_message = f"Error during duplication: {str(e)}"
            print(error_message)  # For logging
            raise gr.Error(error_message) from e

        print("just before getting into pipe")

        # Generate output
        out = pipe(
            prompt=prompt,
            image=logo_dupli,
            mask_image=mask,
            guidance_scale=6,
            height=PANEL_SIZE,
            width=PANEL_SIZE * 2,
            num_inference_steps=28,
            max_sequence_length=256,
            strength=1,
        ).images[0]

        # First yield for progress
        yield None, out

        # Process and return final output
        image_2 = safe_crop_output(out)
        yield image_2, out

    except gr.Error:
        # Already a user-facing error raised above; re-raise unchanged so
        # it is not wrapped a second time by the generic handler below.
        raise
    except Exception as e:
        error_message = f"Error during generation: {str(e)}"
        print(error_message)  # For logging
        raise gr.Error(error_message) from e


# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Logo in Context")
    gr.Markdown(
        "### In-Context LoRA + Image-to-Image, apply your logo to anything"
    )

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                label="Upload Logo Image",
                type="pil",
                height=384,
            )
            prompt_input = gr.Textbox(
                label="Where should the logo be applied?",
                placeholder="e.g., a coffee cup on a wooden table",
                lines=2,
            )
            generate_btn = gr.Button(
                "Generate Application", variant="primary"
            )

        with gr.Column():
            output_image = gr.Image(
                label="Generated Application",
                type="pil",
            )
            output_side = gr.Image(
                label="Side by side",
                type="pil",
            )

    with gr.Row():
        gr.Markdown("""
        ### Instructions:
        1. Upload a logo image (preferably square)
        2. Describe where you'd like to see the logo applied
        3. Click 'Generate Application' and wait for the result

        Note: The generation process might take a few moments.
        """)

    # Set up the click event with error handling
    generate_btn.click(
        fn=generate,
        inputs=[input_image, prompt_input],
        outputs=[output_image, output_side],
        api_name="generate",
    )

# Launch the interface
if __name__ == "__main__":
    demo.launch()