multimodalart HF staff committed on
Commit
b4f042d
1 Parent(s): 75513f7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -161
app.py CHANGED
@@ -2,12 +2,9 @@ import gradio as gr
2
  import torch
3
  import spaces
4
  from diffusers import FluxInpaintPipeline
5
- from PIL import Image #, ImageFile
6
- import io
7
- import numpy as np
8
 
9
- # Enable loading of truncated images
10
- # ImageFile.LOAD_TRUNCATED_IMAGES = True
11
 
12
  # Initialize the pipeline
13
  pipe = FluxInpaintPipeline.from_pretrained(
@@ -20,153 +17,64 @@ pipe.load_lora_weights(
20
  weight_name="visual-identity-design.safetensors"
21
  )
22
 
23
- def safe_open_image(image):
24
- """Safely open and validate image"""
25
- try:
26
- if isinstance(image, np.ndarray):
27
- # Convert numpy array to PIL Image
28
- image = Image.fromarray(image)
29
- elif isinstance(image, bytes):
30
- # Handle bytes input
31
- image = Image.open(io.BytesIO(image))
32
-
33
- # Ensure the image is in RGB mode
34
- if image.mode != 'RGB':
35
- image = image.convert('RGB')
36
-
37
- return image
38
- except Exception as e:
39
- raise ValueError(f"Error processing input image: {str(e)}")
40
-
41
  def square_center_crop(img, target_size=768):
42
- """Improved center crop with additional validation"""
43
- try:
44
- img = safe_open_image(img)
45
-
46
- # Ensure minimum size
47
- if img.size[0] < 64 or img.size[1] < 64:
48
- raise ValueError("Image is too small. Minimum size is 64x64 pixels.")
49
-
50
- width, height = img.size
51
- crop_size = min(width, height)
52
-
53
- # Calculate crop coordinates
54
- left = max(0, (width - crop_size) // 2)
55
- top = max(0, (height - crop_size) // 2)
56
- right = min(width, left + crop_size)
57
- bottom = min(height, top + crop_size)
58
-
59
- img_cropped = img.crop((left, top, right, bottom))
60
-
61
- # Use high-quality resizing
62
- return img_cropped.resize(
63
- (target_size, target_size),
64
- Image.Resampling.LANCZOS,
65
- reducing_gap=3.0
66
- )
67
- except Exception as e:
68
- raise ValueError(f"Error during image cropping: {str(e)}")
69
 
70
  def duplicate_horizontally(img):
71
- """Improved horizontal duplication with validation"""
72
- try:
73
- width, height = img.size
74
- if width != height:
75
- raise ValueError(f"Input image must be square, got {width}x{height}")
76
-
77
- # Create new image with RGB mode explicitly
78
- new_image = Image.new('RGB', (width * 2, height))
79
-
80
- # Ensure the source image is in RGB mode
81
- if img.mode != 'RGB':
82
- img = img.convert('RGB')
83
-
84
- new_image.paste(img, (0, 0))
85
- new_image.paste(img, (width, 0))
86
-
87
- return new_image
88
- except Exception as e:
89
- raise ValueError(f"Error during image duplication: {str(e)}")
90
-
91
- def safe_crop_output(img):
92
- """Safely crop the output image"""
93
- try:
94
- width, height = img.size
95
- half_width = width // 2
96
- return img.crop((half_width, 0, width, height))
97
- except Exception as e:
98
- raise ValueError(f"Error cropping output image: {str(e)}")
99
-
100
- # Load the mask image with error handling
101
- try:
102
- mask = Image.open("mask_square.png")
103
- if mask.mode != 'RGB':
104
- mask = mask.convert('RGB')
105
- except Exception as e:
106
- raise RuntimeError(f"Error loading mask image: {str(e)}")
107
 
108
  @spaces.GPU
109
  def generate(image, prompt_user, progress=gr.Progress(track_tqdm=True)):
110
- """Improved generation function with proper error handling"""
111
- try:
112
- if image is None:
113
- raise ValueError("No input image provided")
114
-
115
- if not prompt_user or prompt_user.strip() == "":
116
- raise ValueError("Please provide a prompt")
117
-
118
- prompt_structure = "The two-panel image showcases the logo of a brand, [LEFT] the left panel is showing the logo [RIGHT] the right panel has this logo applied to "
119
- prompt = prompt_structure + prompt_user
120
-
121
- # Process input image
122
- try:
123
- cropped_image = square_center_crop(image)
124
- except Exception as e:
125
- error_message = f"Error during cropping: {str(e)}"
126
- print(error_message) # For logging
127
- raise gr.Error(error_message)
128
-
129
- yield debug_resize, None, None, None
130
- print("Size after cropping", cropped_image.size)
131
-
132
- try:
133
- logo_dupli = duplicate_horizontally(cropped_image)
134
- except Exception as e:
135
- error_message = f"Error during duplication: {str(e)}"
136
- print(error_message) # For logging
137
- raise gr.Error(error_message)
138
- yield debug_resize, debug_duplicate, None, None
139
- print("just before getting into pipe")
140
- # Generate output
141
- out = pipe(
142
- prompt=prompt,
143
- image=logo_dupli,
144
- mask_image=mask,
145
- guidance_scale=6,
146
- height=768,
147
- width=1536,
148
- num_inference_steps=28,
149
- max_sequence_length=256,
150
- strength=1
151
- ).images[0]
152
-
153
- # First yield for progress
154
- yield debug_resize, debug_duplicate, out, None
155
-
156
- # Process and return final output
157
- image_2 = safe_crop_output(out)
158
- yield debug_resize, debug_duplicate, out, image_2
159
-
160
- except Exception as e:
161
- error_message = f"Error during generation: {str(e)}"
162
- print(error_message) # For logging
163
- raise gr.Error(error_message)
164
-
165
- # Create the Gradio interface
166
  with gr.Blocks() as demo:
167
  gr.Markdown("# Logo in Context")
168
  gr.Markdown("### In-Context LoRA + Image-to-Image, apply your logo to anything")
169
-
170
  with gr.Row():
171
  with gr.Column():
172
  input_image = gr.Image(
@@ -180,37 +88,27 @@ with gr.Blocks() as demo:
180
  lines=2
181
  )
182
  generate_btn = gr.Button("Generate Application", variant="primary")
183
-
184
  with gr.Column():
185
- output_image = gr.Image(
186
- label="Generated Application",
187
- type="pil"
188
- )
189
- output_side = gr.Image(
190
- label="Side by side",
191
- type="pil"
192
- )
193
- debug_resize = gr.Image()
194
- debug_duplicate = gr.Image()
195
-
196
  with gr.Row():
197
  gr.Markdown("""
198
  ### Instructions:
199
  1. Upload a logo image (preferably square)
200
  2. Describe where you'd like to see the logo applied
201
  3. Click 'Generate Application' and wait for the result
202
-
203
  Note: The generation process might take a few moments.
204
  """)
205
-
206
- # Set up the click event with error handling
207
  generate_btn.click(
208
  fn=generate,
209
  inputs=[input_image, prompt_input],
210
- outputs=[debug_resize, debug_duplicate, output_side, output_image],
211
- api_name="generate"
212
  )
213
 
214
  # Launch the interface
215
- if __name__ == "__main__":
216
  demo.launch()
 
2
  import torch
3
  import spaces
4
  from diffusers import FluxInpaintPipeline
5
+ from PIL import Image, ImageFile
 
 
6
 
7
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
 
8
 
9
  # Initialize the pipeline
10
  pipe = FluxInpaintPipeline.from_pretrained(
 
17
  weight_name="visual-identity-design.safetensors"
18
  )
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
def square_center_crop(img, target_size=768):
    """Center-crop a PIL image to a square and resize it.

    Args:
        img: PIL image to crop.
        target_size: Edge length, in pixels, of the returned square image.

    Returns:
        An RGB PIL image of size ``(target_size, target_size)``.
    """
    # Normalise every non-RGB mode (RGBA, P, L, LA, CMYK, ...), not only
    # RGBA and P, so the result is always a three-channel image.
    if img.mode != 'RGB':
        img = img.convert('RGB')

    width, height = img.size
    crop_size = min(width, height)

    # Centre the square window along whichever axis is longer.
    left = (width - crop_size) // 2
    top = (height - crop_size) // 2
    box = (left, top, left + crop_size, top + crop_size)

    return img.crop(box).resize(
        (target_size, target_size), Image.Resampling.LANCZOS
    )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
def duplicate_horizontally(img):
    """Return a two-panel image with ``img`` repeated side by side.

    Args:
        img: Square PIL image.

    Returns:
        A new RGB PIL image twice as wide as ``img`` and of the same height,
        with ``img`` pasted into both the left and the right half.

    Raises:
        ValueError: If ``img`` is not square.
    """
    width, height = img.size
    if width != height:
        raise ValueError(f"Input image must be square, got {width}x{height}")

    new_image = Image.new('RGB', (width * 2, height))
    # Left panel first, then the right panel.
    for x_offset in (0, width):
        new_image.paste(img, (x_offset, 0))
    return new_image
44
+
# Load the mask image once at import time. Image.open is lazy, so copy the
# pixel data inside the context manager: the file is read immediately and its
# handle is released instead of staying open until the first generation.
with Image.open("mask_square.png") as _mask_file:
    mask = _mask_file.copy()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
@spaces.GPU
def generate(image, prompt_user, progress=gr.Progress(track_tqdm=True)):
    """Generate a mock-up of the uploaded logo applied to a described subject.

    The logo is center-cropped to a square, duplicated into a two-panel
    image, and passed to the inpainting pipeline together with the module
    level ``mask``. This is a generator: it yields twice so the full
    two-panel result can be shown before the cropped right panel is ready.

    Args:
        image: Uploaded logo image, or ``None`` when nothing was uploaded.
        prompt_user: Text describing what the logo should be applied to; it
            is appended to the fixed two-panel prompt prefix.
        progress: Gradio progress tracker; not referenced in the body.

    Yields:
        ``(None, out)`` first, then ``(image_2, out)``, where ``out`` is the
        full two-panel pipeline output and ``image_2`` is its right half.

    Raises:
        gr.Error: If no input image was provided.
    """
    # Without this guard a missing upload fails later with an opaque
    # AttributeError inside square_center_crop.
    if image is None:
        raise gr.Error("No input image provided")

    prompt_structure = "The two-panel image showcases the logo of a brand, [LEFT] the left panel is showing the logo [RIGHT] the right panel has this logo applied to "
    prompt = prompt_structure + prompt_user

    cropped_image = square_center_crop(image)
    logo_dupli = duplicate_horizontally(cropped_image)

    out = pipe(
        prompt=prompt,
        image=logo_dupli,
        mask_image=mask,
        guidance_scale=6,
        height=768,
        width=1536,
        num_inference_steps=28,
        max_sequence_length=256,
        strength=1
    ).images[0]

    # Show the full two-panel output right away; the cropped panel follows.
    yield None, out

    # The right half is the panel the prompt asks to carry the applied logo.
    width, height = out.size
    half_width = width // 2
    image_2 = out.crop((half_width, 0, width, height))
    yield image_2, out
73
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  with gr.Blocks() as demo:
75
  gr.Markdown("# Logo in Context")
76
  gr.Markdown("### In-Context LoRA + Image-to-Image, apply your logo to anything")
77
+
78
  with gr.Row():
79
  with gr.Column():
80
  input_image = gr.Image(
 
88
  lines=2
89
  )
90
  generate_btn = gr.Button("Generate Application", variant="primary")
91
+
92
  with gr.Column():
93
+ output_image = gr.Image(label="Generated Application")
94
+ output_side = gr.Image(label="Side by side")
 
 
 
 
 
 
 
 
 
95
  with gr.Row():
96
  gr.Markdown("""
97
  ### Instructions:
98
  1. Upload a logo image (preferably square)
99
  2. Describe where you'd like to see the logo applied
100
  3. Click 'Generate Application' and wait for the result
101
+
102
  Note: The generation process might take a few moments.
103
  """)
104
+
105
+ # Set up the click event
106
  generate_btn.click(
107
  fn=generate,
108
  inputs=[input_image, prompt_input],
109
+ outputs=[output_image, output_side]
 
110
  )
111
 
# Launch the interface only when this file is run as a script.
# The guard must use the dunder names: a bare `name` is undefined and raises
# NameError at startup.
if __name__ == "__main__":
    demo.launch()