# app.py — last change: "Update app.py" by multimodalart (commit 777ad8e, verified), 6.34 kB
import gradio as gr
import torch
import spaces
from diffusers import FluxInpaintPipeline
from PIL import Image #, ImageFile
import io
import numpy as np
# Optional switch for accepting truncated image files. It is left disabled:
# the ImageFile import it needs is commented out in the import section.
# ImageFile.LOAD_TRUNCATED_IMAGES = True

# Build the inpainting pipeline once at import time. The steps below are
# order-dependent: load the base FLUX.1-dev weights in bfloat16, move the
# pipeline to the CUDA device, then attach the In-Context LoRA weights.
pipe = FluxInpaintPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16
)
pipe.to("cuda")
# Only the "visual-identity-design" weight file of the LoRA repo is loaded.
pipe.load_lora_weights(
    "ali-vilab/In-Context-LoRA",
    weight_name="visual-identity-design.safetensors"
)
def safe_open_image(image):
    """Normalize an input image to an RGB image object.

    Args:
        image: A PIL image, a numpy array (converted with
            ``Image.fromarray``), or encoded image data as ``bytes`` /
            ``bytearray`` (opened with ``Image.open``).

    Returns:
        The image in ``'RGB'`` mode. A PIL image that is already RGB is
        returned unchanged (same object).

    Raises:
        ValueError: If the input cannot be turned into an RGB image. The
            underlying exception is attached as ``__cause__``.
    """
    try:
        if isinstance(image, np.ndarray):
            # Convert numpy array to PIL Image
            image = Image.fromarray(image)
        elif isinstance(image, (bytes, bytearray)):
            # Encoded image data: decode it from an in-memory buffer
            image = Image.open(io.BytesIO(image))
        # Ensure the image is in RGB mode
        if image.mode != 'RGB':
            image = image.convert('RGB')
        return image
    except Exception as e:
        # Chain the cause explicitly so the original failure stays visible
        # in tracebacks instead of being reported as an incidental context.
        raise ValueError(f"Error processing input image: {str(e)}") from e
def square_center_crop(img, target_size=768):
    """Center-crop an image to a square and resize it to ``target_size``.

    Args:
        img: Any input accepted by ``safe_open_image``.
        target_size: Edge length, in pixels, of the returned square image.

    Returns:
        An RGB image of size ``(target_size, target_size)``, resized with
        LANCZOS resampling.

    Raises:
        ValueError: If the image cannot be opened, is smaller than 64 pixels
            on either side, or cropping/resizing fails. The underlying
            exception is attached as ``__cause__``.
    """
    try:
        img = safe_open_image(img)
        width, height = img.size
        # Ensure minimum size
        if width < 64 or height < 64:
            raise ValueError("Image is too small. Minimum size is 64x64 pixels.")
        # The crop edge is the shorter side, so the centered box always lies
        # inside the image and needs no clamping.
        crop_size = min(width, height)
        left = (width - crop_size) // 2
        top = (height - crop_size) // 2
        img_cropped = img.crop((left, top, left + crop_size, top + crop_size))
        # Use high-quality resizing
        return img_cropped.resize(
            (target_size, target_size),
            Image.Resampling.LANCZOS,
            reducing_gap=3.0
        )
    except Exception as e:
        # Chain the cause so the original failure is preserved.
        raise ValueError(f"Error during image cropping: {str(e)}") from e
def duplicate_horizontally(img):
    """Build a two-panel image holding the same square image side by side.

    Args:
        img: A square image exposing ``size``, ``mode`` and ``convert``.

    Returns:
        A new RGB image twice as wide as ``img``, with ``img`` pasted at
        ``(0, 0)`` and again at ``(width, 0)``.

    Raises:
        ValueError: If ``img`` is not square or duplication fails. The
            underlying exception is attached as ``__cause__``.
    """
    try:
        width, height = img.size
        if width != height:
            raise ValueError(f"Input image must be square, got {width}x{height}")
        # Create new image with RGB mode explicitly
        new_image = Image.new('RGB', (width * 2, height))
        # Ensure the source image is in RGB mode
        if img.mode != 'RGB':
            img = img.convert('RGB')
        # Left panel, then right panel
        new_image.paste(img, (0, 0))
        new_image.paste(img, (width, 0))
        return new_image
    except Exception as e:
        # Chain the cause so the original failure is preserved.
        raise ValueError(f"Error during image duplication: {str(e)}") from e
def safe_crop_output(img):
    """Return the right half of an image.

    Args:
        img: An image exposing ``size`` and ``crop``.

    Returns:
        The result of cropping ``img`` to the box
        ``(width // 2, 0, width, height)``.

    Raises:
        ValueError: If the image cannot be cropped. The underlying exception
            is attached as ``__cause__``.
    """
    try:
        width, height = img.size
        half_width = width // 2
        return img.crop((half_width, 0, width, height))
    except Exception as e:
        # Chain the cause so the original failure is preserved.
        raise ValueError(f"Error cropping output image: {str(e)}") from e
# Load the inpainting mask once at import time; a failure here aborts startup
# with a RuntimeError that carries the original exception as its cause.
try:
    # Image.open is lazy and keeps the file handle open. The context manager
    # closes it, and convert() hands back a fully loaded RGB copy that stays
    # usable afterwards.
    with Image.open("mask_square.png") as mask_file:
        mask = mask_file.convert('RGB')
except Exception as e:
    raise RuntimeError(f"Error loading mask image: {str(e)}") from e
@spaces.GPU
def generate(image, prompt_user, progress=gr.Progress(track_tqdm=True)):
    """Generate a two-panel image applying the uploaded logo to a described scene.

    The input is center-cropped to a square, duplicated side by side, and
    passed to the inpainting pipeline together with the module-level mask.

    Args:
        image: The logo image; must not be ``None``.
        prompt_user: Text appended to the fixed two-panel prompt; must not
            be empty or whitespace-only.
        progress: Gradio progress object. It is not referenced in the body;
            presumably Gradio uses it to mirror tqdm progress — confirm
            against the Gradio docs.

    Yields:
        ``(None, side_by_side)`` first, then ``(right_half, side_by_side)``,
        where ``side_by_side`` is the full pipeline output and ``right_half``
        is its right half.

    Raises:
        gr.Error: On any validation or generation failure. The message is
            also printed, and the underlying exception is attached as
            ``__cause__``.
    """
    try:
        if image is None:
            raise ValueError("No input image provided")
        if not prompt_user or not prompt_user.strip():
            raise ValueError("Please provide a prompt")
        prompt_structure = "The two-panel image showcases the logo of a brand, [LEFT] the left panel is showing the logo [RIGHT] the right panel has this logo applied to "
        prompt = prompt_structure + prompt_user
        # Process input image: a 768x768 square, duplicated into 1536x768
        cropped_image = square_center_crop(image)
        logo_dupli = duplicate_horizontally(cropped_image)
        # Generate output; height/width match the duplicated input's size
        out = pipe(
            prompt=prompt,
            image=logo_dupli,
            mask_image=mask,
            guidance_scale=6,
            height=768,
            width=1536,
            num_inference_steps=28,
            max_sequence_length=256,
            strength=1
        ).images[0]
        # First yield: show the side-by-side result before the crop is ready
        yield None, out
        # Process and return final output
        image_2 = safe_crop_output(out)
        yield image_2, out
    except Exception as e:
        error_message = f"Error during generation: {str(e)}"
        print(error_message)  # For logging
        # Chain the cause so server-side tracebacks keep the original error.
        raise gr.Error(error_message) from e
# Create the Gradio interface. Components are laid out in creation order
# inside the enclosing Row/Column contexts.
with gr.Blocks() as demo:
    gr.Markdown("# Logo in Context")
    gr.Markdown("### In-Context LoRA + Image-to-Image, apply your logo to anything")
    with gr.Row():
        # Left column: the inputs and the trigger button.
        with gr.Column():
            input_image = gr.Image(
                label="Upload Logo Image",
                type="pil",
                height=384
            )
            prompt_input = gr.Textbox(
                label="Where should the logo be applied?",
                placeholder="e.g., a coffee cup on a wooden table",
                lines=2
            )
            generate_btn = gr.Button("Generate Application", variant="primary")
        # Right column: the cropped result and the full two-panel image.
        with gr.Column():
            output_image = gr.Image(
                label="Generated Application",
                type="pil"
            )
            output_side = gr.Image(
                label="Side by side",
                type="pil"
            )
    with gr.Row():
        gr.Markdown("""
### Instructions:
1. Upload a logo image (preferably square)
2. Describe where you'd like to see the logo applied
3. Click 'Generate Application' and wait for the result
Note: The generation process might take a few moments.
""")
    # Wire the button to generate(). Each pair it yields maps positionally
    # onto (output_image, output_side). Errors are handled inside generate()
    # itself, which re-raises them as gr.Error; nothing is caught here.
    generate_btn.click(
        fn=generate,
        inputs=[input_image, prompt_input],
        outputs=[output_image, output_side],
        api_name="generate"
    )
# Launch the interface only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()