multimodalart HF staff committed on
Commit
ce1bf6b
·
verified ·
1 Parent(s): 6c51e38

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +137 -51
app.py CHANGED
@@ -3,7 +3,10 @@ import torch
3
  import spaces
4
  from diffusers import FluxInpaintPipeline
5
  from PIL import Image, ImageFile
 
 
6
 
 
7
  ImageFile.LOAD_TRUNCATED_IMAGES = True
8
 
9
  # Initialize the pipeline
@@ -17,60 +20,134 @@ pipe.load_lora_weights(
17
  weight_name="visual-identity-design.safetensors"
18
  )
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  def square_center_crop(img, target_size=768):
21
- if img.mode in ('RGBA', 'P'):
22
- img = img.convert('RGB')
23
-
24
- width, height = img.size
25
- crop_size = min(width, height)
26
-
27
- left = (width - crop_size) // 2
28
- top = (height - crop_size) // 2
29
- right = left + crop_size
30
- bottom = top + crop_size
31
-
32
- img_cropped = img.crop((left, top, right, bottom))
33
- return img_cropped.resize((target_size, target_size), Image.Resampling.LANCZOS)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  def duplicate_horizontally(img):
36
- width, height = img.size
37
- if width != height:
38
- raise ValueError(f"Input image must be square, got {width}x{height}")
39
-
40
- new_image = Image.new('RGB', (width * 2, height))
41
- new_image.paste(img, (0, 0))
42
- new_image.paste(img, (width, 0))
43
- return new_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
- # Load the mask image
46
- mask = Image.open("mask_square.png")
 
 
 
 
 
47
 
48
  @spaces.GPU
49
  def generate(image, prompt_user, progress=gr.Progress(track_tqdm=True)):
50
- prompt_structure = "The two-panel image showcases the logo of a brand, [LEFT] the left panel is showing the logo [RIGHT] the right panel has this logo applied to "
51
- prompt = prompt_structure + prompt_user
52
-
53
- cropped_image = square_center_crop(image)
54
- logo_dupli = duplicate_horizontally(cropped_image)
55
-
56
- out = pipe(
57
- prompt=prompt,
58
- image=logo_dupli,
59
- mask_image=mask,
60
- guidance_scale=6,
61
- height=768,
62
- width=1536,
63
- num_inference_steps=28,
64
- max_sequence_length=256,
65
- strength=1
66
- ).images[0]
67
-
68
- yield None, out
69
- width, height = out.size
70
- half_width = width // 2
71
- image_2 = out.crop((half_width, 0, width, height))
72
- yield image_2, out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
 
74
  with gr.Blocks() as demo:
75
  gr.Markdown("# Logo in Context")
76
  gr.Markdown("### In-Context LoRA + Image-to-Image, apply your logo to anything")
@@ -80,7 +157,8 @@ with gr.Blocks() as demo:
80
  input_image = gr.Image(
81
  label="Upload Logo Image",
82
  type="pil",
83
- height=384
 
84
  )
85
  prompt_input = gr.Textbox(
86
  label="Where should the logo be applied?",
@@ -90,8 +168,15 @@ with gr.Blocks() as demo:
90
  generate_btn = gr.Button("Generate Application", variant="primary")
91
 
92
  with gr.Column():
93
- output_image = gr.Image(label="Generated Application")
94
- output_side = gr.Image(label="Side by side")
 
 
 
 
 
 
 
95
  with gr.Row():
96
  gr.Markdown("""
97
  ### Instructions:
@@ -102,11 +187,12 @@ with gr.Blocks() as demo:
102
  Note: The generation process might take a few moments.
103
  """)
104
 
105
- # Set up the click event
106
  generate_btn.click(
107
  fn=generate,
108
  inputs=[input_image, prompt_input],
109
- outputs=[output_image, output_side]
 
110
  )
111
 
112
  # Launch the interface
 
3
  import spaces
4
  from diffusers import FluxInpaintPipeline
5
  from PIL import Image, ImageFile
6
+ import io
7
+ import numpy as np
8
 
9
+ # Enable loading of truncated images
10
  ImageFile.LOAD_TRUNCATED_IMAGES = True
11
 
12
  # Initialize the pipeline
 
20
  weight_name="visual-identity-design.safetensors"
21
  )
22
 
23
def safe_open_image(image):
    """Normalize an input image to an RGB PIL image.

    Args:
        image: A PIL image, a numpy array (converted with
            ``Image.fromarray``), or raw encoded bytes (decoded with
            ``Image.open``).

    Returns:
        The image in ``'RGB'`` mode. An input that is already RGB is
        returned as-is (no copy).

    Raises:
        ValueError: If the input type is unsupported, or decoding or mode
            conversion fails. The underlying exception is chained as
            ``__cause__``.
    """
    try:
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        elif isinstance(image, (bytes, bytearray)):
            image = Image.open(io.BytesIO(image))

        # Fail with a readable message instead of an AttributeError leaking
        # out of the mode check below (e.g. when the input is None or a str).
        if not hasattr(image, "mode") or not hasattr(image, "convert"):
            raise TypeError(f"unsupported input type {type(image).__name__}")

        # NOTE(review): convert('RGB') discards any alpha channel; confirm
        # that is the desired treatment for transparent logos.
        if image.mode != 'RGB':
            image = image.convert('RGB')

        return image
    except Exception as e:
        raise ValueError(f"Error processing input image: {e}") from e
40
+
41
def square_center_crop(img, target_size=768):
    """Center-crop an image to a square and resize it.

    Args:
        img: Input accepted by ``safe_open_image`` (it is normalized to RGB
            before cropping).
        target_size: Edge length, in pixels, of the returned square image.

    Returns:
        A ``target_size`` x ``target_size`` image resampled with LANCZOS.

    Raises:
        ValueError: If the input cannot be opened, if either dimension is
            smaller than 64 pixels, or if cropping/resizing fails.
    """
    # safe_open_image already raises a descriptive ValueError; let it
    # propagate instead of wrapping it in a second prefix.
    img = safe_open_image(img)

    width, height = img.size
    if width < 64 or height < 64:
        raise ValueError("Image is too small. Minimum size is 64x64 pixels.")

    # The largest centered square always lies inside the image, so no
    # clamping of the coordinates is needed.
    crop_size = min(width, height)
    left = (width - crop_size) // 2
    top = (height - crop_size) // 2

    try:
        img_cropped = img.crop((left, top, left + crop_size, top + crop_size))
        return img_cropped.resize(
            (target_size, target_size),
            Image.Resampling.LANCZOS,
            reducing_gap=3.0,
        )
    except Exception as e:
        raise ValueError(f"Error during image cropping: {e}") from e
69
 
70
def duplicate_horizontally(img):
    """Build a two-panel image with the same square picture in both halves.

    Args:
        img: A square PIL image; it is converted to RGB when needed.

    Returns:
        A new RGB image of size ``(2 * width, height)`` with ``img`` pasted
        at x=0 and again at x=width.

    Raises:
        ValueError: If ``img`` is not square, or if reading, converting or
            pasting the image fails (original exception chained).
    """
    try:
        width, height = img.size
    except Exception as e:
        raise ValueError(f"Error during image duplication: {e}") from e

    # Raised outside the try block so the message is not wrapped a second
    # time by this function's own handler.
    if width != height:
        raise ValueError(f"Input image must be square, got {width}x{height}")

    try:
        if img.mode != 'RGB':
            img = img.convert('RGB')

        new_image = Image.new('RGB', (width * 2, height))
        new_image.paste(img, (0, 0))
        new_image.paste(img, (width, 0))
        return new_image
    except Exception as e:
        raise ValueError(f"Error during image duplication: {e}") from e
90
+
91
def safe_crop_output(img):
    """Return the right half of a two-panel image.

    Args:
        img: Image exposing ``size`` as ``(width, height)`` and a
            ``crop(box)`` method.

    Returns:
        The result of cropping to ``(width // 2, 0, width, height)``. For an
        odd width the returned half is one pixel wider than the left half.

    Raises:
        ValueError: If the size cannot be read or the crop fails (original
            exception chained).
    """
    try:
        width, height = img.size
        return img.crop((width // 2, 0, width, height))
    except Exception as e:
        raise ValueError(f"Error cropping output image: {e}") from e
99
 
100
# Load the inpainting mask once at import time. Image.open is lazy, so the
# pixels are read via convert() inside the with-block; convert() returns a
# separate in-memory image, which lets the file handle be closed right away.
try:
    with Image.open("mask_square.png") as _mask_file:
        mask = _mask_file.convert('RGB')
except Exception as e:
    raise RuntimeError(f"Error loading mask image: {str(e)}") from e
107
 
108
  @spaces.GPU
109
  def generate(image, prompt_user, progress=gr.Progress(track_tqdm=True)):
110
+ """Improved generation function with proper error handling"""
111
+ try:
112
+ if image is None:
113
+ raise ValueError("No input image provided")
114
+
115
+ if not prompt_user or prompt_user.strip() == "":
116
+ raise ValueError("Please provide a prompt")
117
+
118
+ prompt_structure = "The two-panel image showcases the logo of a brand, [LEFT] the left panel is showing the logo [RIGHT] the right panel has this logo applied to "
119
+ prompt = prompt_structure + prompt_user
120
+
121
+ # Process input image
122
+ cropped_image = square_center_crop(image)
123
+ logo_dupli = duplicate_horizontally(cropped_image)
124
+
125
+ # Generate output
126
+ out = pipe(
127
+ prompt=prompt,
128
+ image=logo_dupli,
129
+ mask_image=mask,
130
+ guidance_scale=6,
131
+ height=768,
132
+ width=1536,
133
+ num_inference_steps=28,
134
+ max_sequence_length=256,
135
+ strength=1
136
+ ).images[0]
137
+
138
+ # First yield for progress
139
+ yield None, out
140
+
141
+ # Process and return final output
142
+ image_2 = safe_crop_output(out)
143
+ yield image_2, out
144
+
145
+ except Exception as e:
146
+ error_message = f"Error during generation: {str(e)}"
147
+ print(error_message) # For logging
148
+ raise gr.Error(error_message)
149
 
150
+ # Create the Gradio interface
151
  with gr.Blocks() as demo:
152
  gr.Markdown("# Logo in Context")
153
  gr.Markdown("### In-Context LoRA + Image-to-Image, apply your logo to anything")
 
157
  input_image = gr.Image(
158
  label="Upload Logo Image",
159
  type="pil",
160
+ height=384,
161
+ tool=None # Disable editing tools to prevent corruption
162
  )
163
  prompt_input = gr.Textbox(
164
  label="Where should the logo be applied?",
 
168
  generate_btn = gr.Button("Generate Application", variant="primary")
169
 
170
  with gr.Column():
171
+ output_image = gr.Image(
172
+ label="Generated Application",
173
+ type="pil"
174
+ )
175
+ output_side = gr.Image(
176
+ label="Side by side",
177
+ type="pil"
178
+ )
179
+
180
  with gr.Row():
181
  gr.Markdown("""
182
  ### Instructions:
 
187
  Note: The generation process might take a few moments.
188
  """)
189
 
190
+ # Set up the click event with error handling
191
  generate_btn.click(
192
  fn=generate,
193
  inputs=[input_image, prompt_input],
194
+ outputs=[output_image, output_side],
195
+ api_name="generate"
196
  )
197
 
198
  # Launch the interface