# Hugging Face Space: running on Zero
import io

import gradio as gr
import torch
import spaces
from diffusers import FluxInpaintPipeline
from PIL import Image  # , ImageFile
import numpy as np
# Enable loading of truncated images
# ImageFile.LOAD_TRUNCATED_IMAGES = True
# Initialize the pipeline
# Module-level setup: the inpainting pipeline is built once at import time and
# shared by every call to the generation function.
pipe = FluxInpaintPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")
# Attach the LoRA weights on top of the base model.
pipe.load_lora_weights(
    "ali-vilab/In-Context-LoRA",
    weight_name="visual-identity-design.safetensors",
)
def safe_open_image(image):
    """Normalize a supported image input to an RGB image object.

    Args:
        image: A ``numpy.ndarray``, raw encoded image ``bytes``/``bytearray``,
            or an object that already exposes ``mode`` and ``convert`` (such
            as a PIL image).

    Returns:
        The image in ``'RGB'`` mode. An input whose ``mode`` is already
        ``'RGB'`` is returned as the same object, without copying.

    Raises:
        ValueError: If the input cannot be decoded or converted. The
            underlying exception is attached as ``__cause__``.
    """
    try:
        if isinstance(image, np.ndarray):
            # Convert numpy array to PIL Image
            image = Image.fromarray(image)
        elif isinstance(image, (bytes, bytearray)):
            # Encoded image data: decode it from an in-memory buffer
            image = Image.open(io.BytesIO(image))
        # Ensure the image is in RGB mode
        # NOTE(review): convert('RGB') discards any alpha channel instead of
        # compositing it onto a background; confirm this is acceptable for
        # transparent logos.
        if image.mode != 'RGB':
            image = image.convert('RGB')
        return image
    except Exception as e:
        # Chain the original error so the real cause stays in the traceback.
        raise ValueError(f"Error processing input image: {str(e)}") from e
def square_center_crop(img, target_size=768):
    """Center-crop an image to a square and resize it to ``target_size``.

    The input is first normalized with ``safe_open_image``. The largest
    centered square is then cut out and resampled with LANCZOS.

    Args:
        img: Any input accepted by ``safe_open_image``.
        target_size: Edge length in pixels of the returned square image.

    Returns:
        A ``target_size`` x ``target_size`` image.

    Raises:
        ValueError: If the image is smaller than 64 pixels on either side, or
            if any step fails. Every failure, including the size check, is
            re-raised with the "Error during image cropping" prefix and the
            original exception as ``__cause__``.
    """
    try:
        img = safe_open_image(img)
        width, height = img.size
        # Ensure minimum size
        if width < 64 or height < 64:
            raise ValueError("Image is too small. Minimum size is 64x64 pixels.")
        # The square side is the shorter edge, so the centered offsets are
        # never negative and the box always lies inside the image; no
        # clamping is needed.
        crop_size = min(width, height)
        left = (width - crop_size) // 2
        top = (height - crop_size) // 2
        img_cropped = img.crop((left, top, left + crop_size, top + crop_size))
        # Use high-quality resizing
        return img_cropped.resize(
            (target_size, target_size),
            Image.Resampling.LANCZOS,
            reducing_gap=3.0,
        )
    except Exception as e:
        raise ValueError(f"Error during image cropping: {str(e)}") from e
def duplicate_horizontally(img):
    """Build a two-panel image holding two copies of a square image.

    Args:
        img: A square image exposing ``size``, ``mode`` and ``convert``.

    Returns:
        A new RGB image of size ``(2 * width, height)`` with ``img`` pasted
        at x=0 and again at x=width.

    Raises:
        ValueError: If the input is not square or any step fails. Every
            failure is re-raised with the "Error during image duplication"
            prefix and the original exception as ``__cause__``.
    """
    try:
        width, height = img.size
        if width != height:
            raise ValueError(f"Input image must be square, got {width}x{height}")
        # Ensure the source image is in RGB mode before pasting
        if img.mode != 'RGB':
            img = img.convert('RGB')
        # Create the double-width canvas only once the input is validated
        new_image = Image.new('RGB', (width * 2, height))
        new_image.paste(img, (0, 0))
        new_image.paste(img, (width, 0))
        return new_image
    except Exception as e:
        raise ValueError(f"Error during image duplication: {str(e)}") from e
def safe_crop_output(img):
    """Return the right half of an image.

    Args:
        img: An image exposing ``size`` and ``crop``.

    Returns:
        The result of cropping to the box ``(width // 2, 0, width, height)``.
        For an odd width, the right half keeps the extra column.

    Raises:
        ValueError: If the size cannot be read or the crop fails. The
            original exception is attached as ``__cause__``.
    """
    try:
        width, height = img.size
        return img.crop((width // 2, 0, width, height))
    except Exception as e:
        raise ValueError(f"Error cropping output image: {str(e)}") from e
# Load the mask image with error handling
# Image.open() is lazy and keeps the file handle open. Opening it in a `with`
# block and converting inside it yields a fully loaded RGB copy that no longer
# depends on the file, which is closed on exit.
try:
    with Image.open("mask_square.png") as mask_file:
        mask = mask_file.convert('RGB')
except Exception as e:
    # Fail at import time: the app cannot run inpainting without its mask.
    raise RuntimeError(f"Error loading mask image: {str(e)}") from e
@spaces.GPU
def generate(image, prompt_user, progress=gr.Progress(track_tqdm=True)):
    """Generate an image of the uploaded logo applied to a described object.

    The logo is center-cropped to a square and duplicated side by side. The
    pipeline is then run on that two-panel image with the module-level mask.

    Args:
        image: The uploaded logo; must not be ``None``.
        prompt_user: Description of what the logo should be applied to; must
            contain non-whitespace text.
        progress: Gradio progress tracker. It is not used directly in the
            body; it is declared with ``track_tqdm=True`` in the signature.

    Yields:
        ``(None, two_panel_image)`` first, then
        ``(right_panel, two_panel_image)`` once the right half is cropped.

    Raises:
        gr.Error: On invalid input or any failure during processing. The
            original exception is attached as ``__cause__``.
    """
    # Edge length of one panel; the pipeline canvas is two panels wide.
    panel_size = 768
    try:
        if image is None:
            raise ValueError("No input image provided")
        if not prompt_user or not prompt_user.strip():
            raise ValueError("Please provide a prompt")
        prompt_structure = "The two-panel image showcases the logo of a brand, [LEFT] the left panel is showing the logo [RIGHT] the right panel has this logo applied to "
        # Strip the user text so stray whitespace/newlines do not end up in
        # the middle or at the end of the final prompt.
        prompt = prompt_structure + prompt_user.strip()
        # Process input image
        cropped_image = square_center_crop(image, panel_size)
        logo_dupli = duplicate_horizontally(cropped_image)
        # Generate output
        out = pipe(
            prompt=prompt,
            image=logo_dupli,
            mask_image=mask,
            guidance_scale=6,
            height=panel_size,
            width=panel_size * 2,
            num_inference_steps=28,
            max_sequence_length=256,
            strength=1,
        ).images[0]
        # First yield: publish the full two-panel result right away
        yield None, out
        # Process and return final output
        image_2 = safe_crop_output(out)
        yield image_2, out
    except Exception as e:
        error_message = f"Error during generation: {str(e)}"
        print(error_message)  # For logging
        raise gr.Error(error_message) from e
# Create the Gradio interface
# Layout: a row with an input column (logo, prompt, button) and an output
# column (cropped result, full side-by-side image), followed by a row of
# usage instructions.
with gr.Blocks() as demo:
    gr.Markdown("# Logo in Context")
    gr.Markdown("### In-Context LoRA + Image-to-Image, apply your logo to anything")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                label="Upload Logo Image",
                type="pil",
                height=384,
            )
            prompt_input = gr.Textbox(
                label="Where should the logo be applied?",
                placeholder="e.g., a coffee cup on a wooden table",
                lines=2,
            )
            generate_btn = gr.Button("Generate Application", variant="primary")
        with gr.Column():
            output_image = gr.Image(
                label="Generated Application",
                type="pil",
            )
            output_side = gr.Image(
                label="Side by side",
                type="pil",
            )
    with gr.Row():
        # The blank line before "Note:" keeps it a separate paragraph;
        # without it Markdown treats the line as a continuation of item 3.
        gr.Markdown("""
        ### Instructions:
        1. Upload a logo image (preferably square)
        2. Describe where you'd like to see the logo applied
        3. Click 'Generate Application' and wait for the result

        Note: The generation process might take a few moments.
        """)
    # Set up the click event with error handling
    # The two outputs receive, in order, the two values yielded by `generate`.
    generate_btn.click(
        fn=generate,
        inputs=[input_image, prompt_input],
        outputs=[output_image, output_side],
        api_name="generate",
    )
# Launch the interface
# Only start the server when this file is executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()