Gopalag commited on
Commit
8a2fb71
·
verified ·
1 Parent(s): 8b0903d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +156 -260
app.py CHANGED
@@ -1,143 +1,132 @@
1
  import gradio as gr
2
  import numpy as np
3
  import random
4
- import spaces
5
  import torch
6
- from diffusers import DiffusionPipeline
7
  from PIL import Image
8
- import io
9
- from PIL import ImageEnhance
 
 
10
 
11
  def get_edge_color(image):
12
- """
13
- Get a random color from the edge of the image
14
- """
15
- # Convert to numpy array
16
  img_array = np.array(image)
17
-
18
- # Get pixels from all edges
19
  top_edge = img_array[0, :, :]
20
  bottom_edge = img_array[-1, :, :]
21
  left_edge = img_array[:, 0, :]
22
  right_edge = img_array[:, -1, :]
23
-
24
- # Combine all edge pixels
25
  edge_pixels = np.concatenate([top_edge, bottom_edge, left_edge, right_edge])
26
-
27
- # Pick a random edge pixel
28
  random_edge_color = tuple(edge_pixels[random.randint(0, len(edge_pixels)-1)])
29
-
30
  return random_edge_color
31
 
32
  def color_match_tshirt(tshirt_image, target_color, threshold=30):
33
- """
34
- Change white/near-white areas of the t-shirt to the target color
35
- """
36
- # Convert to numpy array
37
  img_array = np.array(tshirt_image)
38
-
39
- # Create a mask for near-white pixels
40
  white_mask = np.all(np.abs(img_array - [255, 255, 255]) < threshold, axis=2)
41
-
42
- # Apply the new color to masked areas
43
  img_array[white_mask] = target_color
44
-
45
  return Image.fromarray(img_array)
46
 
47
- def add_logo_watermark(image, logo_path='logo.png', size_percentage=0.2):
48
- """
49
- Add a logo watermark to the bottom right corner
50
- """
51
- # Open and resize logo
52
- logo = Image.open(logo_path)
53
-
54
- # Calculate new logo size (20% of main image width)
55
- new_width = int(image.size[0] * size_percentage)
56
- new_height = int(new_width * logo.size[1] / logo.size[0])
57
- logo = logo.resize((new_width, new_height), Image.Resampling.LANCZOS)
58
-
59
- # If logo has transparency, use it as mask
60
- if logo.mode == 'RGBA':
61
- mask = logo.split()[3]
62
- else:
63
- mask = None
64
-
65
- # Calculate position (bottom right corner with padding)
66
- position = (image.size[0] - logo.size[0] - 10,
67
- image.size[1] - logo.size[1] - 10)
68
-
69
- # Create a copy of the image
70
- result = image.copy()
71
- result.paste(logo, position, mask)
72
-
73
- return result
 
 
 
 
 
74
 
75
- def create_tshirt_preview(design_image, tshirt_color="white"):
76
- """
77
- Overlay the design onto the t-shirt template with color matching
78
- """
79
- # Load the template t-shirt image
80
- tshirt = Image.open('image.jpeg')
81
- tshirt_width, tshirt_height = tshirt.size
82
-
83
- # Get a random edge color from the design
84
- edge_color = get_edge_color(design_image)
85
-
86
- # Color match the t-shirt
87
- tshirt = color_match_tshirt(tshirt, edge_color)
88
-
89
- # Resize design (35% of shirt width)
90
- design_width = int(tshirt_width * 0.35)
91
- design_height = int(design_width * design_image.size[1] / design_image.size[0])
92
- design_image = design_image.resize((design_width, design_height), Image.Resampling.LANCZOS)
93
-
94
- # Calculate position to center design
95
- x = (tshirt_width - design_width) // 2
96
- y = int(tshirt_height * 0.2)
97
-
98
- # Create mask if design has transparency
99
- if design_image.mode == 'RGBA':
100
- mask = design_image.split()[3]
101
- else:
102
- mask = None
103
-
104
- # Paste design onto shirt
105
- tshirt.paste(design_image, (x, y), mask)
106
-
107
- # Add logo watermark
108
- tshirt = add_logo_watermark(tshirt)
109
-
110
- return tshirt
111
 
112
- @spaces.GPU()
113
- def infer(prompt, style=None, tshirt_color="white", seed=42, randomize_seed=False,
114
- width=1024, height=1024, num_inference_steps=4,
115
- progress=gr.Progress(track_tqdm=True)):
116
- if randomize_seed:
117
- seed = random.randint(0, MAX_SEED)
118
-
119
- enhanced_prompt = enhance_prompt_for_tshirt(prompt, style)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  generator = torch.Generator().manual_seed(seed)
121
 
122
- # Generate the design
123
- design_image = pipe(
124
  prompt=enhanced_prompt,
125
  width=width,
126
  height=height,
127
  num_inference_steps=num_inference_steps,
128
  generator=generator,
129
- guidance_scale=0.0
130
  ).images[0]
131
 
132
- # Add logo to design
133
- design_image = add_logo_watermark(design_image)
134
-
135
- # Create t-shirt preview
136
- tshirt_preview = create_tshirt_preview(design_image, tshirt_color)
137
-
138
- return design_image, tshirt_preview, seed
139
 
140
- # Available t-shirt colors
141
  TSHIRT_COLORS = {
142
  "White": "#FFFFFF",
143
  "Black": "#000000",
@@ -145,14 +134,7 @@ TSHIRT_COLORS = {
145
  "Gray": "#808080"
146
  }
147
 
148
- examples = [
149
- ["Cool geometric mountain landscape", "minimal", "White"],
150
- ["Vintage motorcycle with flames", "vintage", "Black"],
151
- ["flamingo in scenic forset", "realistic", "White"],
152
- ["Adventure Starts typography", "typography", "White"]
153
- ]
154
-
155
- styles = [
156
  "minimal",
157
  "vintage",
158
  "artistic",
@@ -161,153 +143,67 @@ styles = [
161
  "realistic"
162
  ]
163
 
164
- css = """
165
- #col-container {
166
- margin: 0 auto;
167
- max-width: 1200px !important;
168
- padding: 20px;
169
- }
170
- .main-title {
171
- text-align: center;
172
- color: #2d3748;
173
- margin-bottom: 1rem;
174
- font-family: 'Poppins', sans-serif;
175
- }
176
- .subtitle {
177
- text-align: center;
178
- color: #4a5568;
179
- margin-bottom: 2rem;
180
- font-family: 'Inter', sans-serif;
181
- font-size: 0.95rem;
182
- line-height: 1.5;
183
- }
184
- .design-input {
185
- border: 2px solid #e2e8f0;
186
- border-radius: 10px;
187
- padding: 12px !important;
188
- margin-bottom: 1rem !important;
189
- font-size: 1rem;
190
- transition: all 0.3s ease;
191
- }
192
- .results-row {
193
- display: grid;
194
- grid-template-columns: 1fr 1fr;
195
- gap: 20px;
196
- margin-top: 20px;
197
- }
198
- """
199
 
200
- with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
201
- with gr.Column(elem_id="col-container"):
202
- gr.Markdown(
203
- """
204
- # 👕Deradh's T-Shirt Design Generator
205
- """,
206
- elem_classes=["main-title"]
207
- )
208
-
209
- gr.Markdown(
210
- """
211
- Create unique t-shirt designs using Deradh's AI.
212
- Describe your design idea and select a style to generate professional-quality artwork
213
- perfect for custom t-shirts.
214
- """,
215
- elem_classes=["subtitle"]
216
- )
217
-
218
- with gr.Row():
219
- with gr.Column(scale=2):
220
- prompt = gr.Text(
221
- label="Design Description",
222
- show_label=False,
223
- max_lines=1,
224
- placeholder="Describe your t-shirt design idea",
225
- container=False,
226
- elem_classes=["design-input"]
227
- )
228
- with gr.Column(scale=1):
229
- style = gr.Dropdown(
230
- choices=[""] + styles,
231
- value="",
232
- label="Style",
233
- container=False
234
- )
235
- with gr.Column(scale=1):
236
- tshirt_color = gr.Dropdown(
237
- choices=list(TSHIRT_COLORS.keys()),
238
- value="White",
239
- label="T-Shirt Color",
240
- container=False
241
- )
242
- run_button = gr.Button(
243
- "✨ Generate",
244
- scale=0,
245
- elem_classes=["generate-button"]
246
- )
247
-
248
- with gr.Row(elem_classes=["results-row"]):
249
- result = gr.Image(
250
- label="Generated Design",
251
- show_label=True,
252
- elem_classes=["result-image"]
253
  )
254
- preview = gr.Image(
255
- label="T-Shirt Preview",
256
- show_label=True,
257
- elem_classes=["preview-image"]
 
258
  )
259
-
260
- with gr.Accordion("🔧 Advanced Settings", open=False):
261
- with gr.Group():
262
- seed = gr.Slider(
263
- label="Design Seed",
264
- minimum=0,
265
- maximum=MAX_SEED,
266
- step=1,
267
- value=0,
268
- )
269
- randomize_seed = gr.Checkbox(
270
- label="Randomize Design",
271
- value=True
272
- )
273
-
274
- with gr.Row():
275
- width = gr.Slider(
276
- label="Width",
277
- minimum=256,
278
- maximum=MAX_IMAGE_SIZE,
279
- step=32,
280
- value=1024,
281
- )
282
- height = gr.Slider(
283
- label="Height",
284
- minimum=256,
285
- maximum=MAX_IMAGE_SIZE,
286
- step=32,
287
- value=1024,
288
- )
289
-
290
- num_inference_steps = gr.Slider(
291
- label="Generation Quality (Steps)",
292
- minimum=1,
293
- maximum=50,
294
- step=1,
295
- value=4,
296
- )
297
-
298
- gr.Examples(
299
- examples=examples,
300
- fn=infer,
301
- inputs=[prompt, style, tshirt_color],
302
- outputs=[result, preview, seed],
303
- cache_examples=True
304
- )
305
-
306
- gr.on(
307
- triggers=[run_button.click, prompt.submit],
308
- fn=infer,
309
- inputs=[prompt, style, tshirt_color, seed, randomize_seed, width, height, num_inference_steps],
310
- outputs=[result, preview, seed]
311
- )
312
 
313
- demo.launch()
 
 
 
1
  import gradio as gr
2
  import numpy as np
3
  import random
 
4
  import torch
5
+ from diffusers import StableDiffusionPipeline
6
  from PIL import Image
7
+ import os
8
+
9
+ MAX_SEED = 10000
10
+ MAX_IMAGE_SIZE = 1024
11
 
12
  def get_edge_color(image):
13
+ """Get a random color from the edge of the image"""
 
 
 
14
  img_array = np.array(image)
 
 
15
  top_edge = img_array[0, :, :]
16
  bottom_edge = img_array[-1, :, :]
17
  left_edge = img_array[:, 0, :]
18
  right_edge = img_array[:, -1, :]
 
 
19
  edge_pixels = np.concatenate([top_edge, bottom_edge, left_edge, right_edge])
 
 
20
  random_edge_color = tuple(edge_pixels[random.randint(0, len(edge_pixels)-1)])
 
21
  return random_edge_color
22
 
23
  def color_match_tshirt(tshirt_image, target_color, threshold=30):
24
+ """Change white/near-white areas of the t-shirt to the target color"""
 
 
 
25
  img_array = np.array(tshirt_image)
 
 
26
  white_mask = np.all(np.abs(img_array - [255, 255, 255]) < threshold, axis=2)
 
 
27
  img_array[white_mask] = target_color
 
28
  return Image.fromarray(img_array)
29
 
30
+ def add_watermark(image, logo_path, position='bottom-right', size_percentage=10):
31
+ """Add a watermark to an image"""
32
+ try:
33
+ if not os.path.exists(logo_path):
34
+ return image
35
+
36
+ logo = Image.open(logo_path).convert('RGBA')
37
+ main_width, main_height = image.size
38
+ logo_width = int(main_width * size_percentage / 100)
39
+ logo_height = int(logo.size[1] * (logo_width / logo.size[0]))
40
+ logo = logo.resize((logo_width, logo_height), Image.Resampling.LANCZOS)
41
+
42
+ if image.mode != 'RGBA':
43
+ image = image.convert('RGBA')
44
+
45
+ watermarked = Image.new('RGBA', image.size, (0, 0, 0, 0))
46
+ watermarked.paste(image, (0, 0))
47
+
48
+ if position == 'bottom-right':
49
+ pos = (main_width - logo_width - 10, main_height - logo_height - 10)
50
+ elif position == 'bottom-left':
51
+ pos = (10, main_height - logo_height - 10)
52
+ elif position == 'top-right':
53
+ pos = (main_width - logo_width - 10, 10)
54
+ else: # top-left
55
+ pos = (10, 10)
56
+
57
+ watermarked.paste(logo, pos, logo)
58
+ return watermarked.convert('RGB')
59
+ except Exception as e:
60
+ print(f"Failed to add watermark: {str(e)}")
61
+ return image
62
 
63
+ def create_tshirt_preview(design_image, tshirt_template_path, tshirt_color="white"):
64
+ """Create a preview of the design on a t-shirt"""
65
+ try:
66
+ tshirt = Image.open(tshirt_template_path)
67
+ tshirt_width, tshirt_height = tshirt.size
68
+
69
+ edge_color = get_edge_color(design_image)
70
+ tshirt = color_match_tshirt(tshirt, edge_color)
71
+
72
+ design_width = int(tshirt_width * 0.35)
73
+ design_height = int(design_width * design_image.size[1] / design_image.size[0])
74
+ design_image = design_image.resize((design_width, design_height), Image.Resampling.LANCZOS)
75
+
76
+ x = (tshirt_width - design_width) // 2
77
+ y = int(tshirt_height * 0.2)
78
+
79
+ if design_image.mode == 'RGBA':
80
+ mask = design_image.split()[3]
81
+ else:
82
+ mask = None
83
+
84
+ tshirt.paste(design_image, (x, y), mask)
85
+ return tshirt
86
+ except Exception as e:
87
+ print(f"Failed to create t-shirt preview: {str(e)}")
88
+ return design_image
 
 
 
 
 
 
 
 
 
 
89
 
90
+ def enhance_prompt(prompt, style):
91
+ """Enhance the prompt based on selected style"""
92
+ if not style:
93
+ return prompt
94
+
95
+ style_prompts = {
96
+ "minimal": "minimalist design, clean lines, simple shapes",
97
+ "vintage": "vintage style, retro, distressed texture",
98
+ "artistic": "artistic, creative, hand-drawn style",
99
+ "geometric": "geometric patterns, abstract shapes",
100
+ "typography": "modern typography, creative lettering",
101
+ "realistic": "photorealistic, detailed illustration"
102
+ }
103
+
104
+ return f"{prompt}, {style_prompts.get(style, '')}"
105
+
106
+ def initialize_pipeline():
107
+ """Initialize the Stable Diffusion pipeline"""
108
+ model_id = "stabilityai/stable-diffusion-2-1"
109
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
110
+ if torch.cuda.is_available():
111
+ pipe = pipe.to("cuda")
112
+ return pipe
113
+
114
+ def generate_design(prompt, style, seed, width, height, num_inference_steps, pipe):
115
+ """Generate the design using Stable Diffusion"""
116
+ enhanced_prompt = enhance_prompt(prompt, style)
117
  generator = torch.Generator().manual_seed(seed)
118
 
119
+ image = pipe(
 
120
  prompt=enhanced_prompt,
121
  width=width,
122
  height=height,
123
  num_inference_steps=num_inference_steps,
124
  generator=generator,
 
125
  ).images[0]
126
 
127
+ return image
 
 
 
 
 
 
128
 
129
+ # Constants
130
  TSHIRT_COLORS = {
131
  "White": "#FFFFFF",
132
  "Black": "#000000",
 
134
  "Gray": "#808080"
135
  }
136
 
137
+ STYLES = [
 
 
 
 
 
 
 
138
  "minimal",
139
  "vintage",
140
  "artistic",
 
143
  "realistic"
144
  ]
145
 
146
+ EXAMPLES = [
147
+ ["Cool geometric mountain landscape", "minimal", "White"],
148
+ ["Vintage motorcycle with flames", "vintage", "Black"],
149
+ ["Flamingo in scenic forest", "realistic", "White"],
150
+ ["Adventure Starts typography", "typography", "White"]
151
+ ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
+ # Gradio Interface
154
+ def create_interface():
155
+ pipe = initialize_pipeline()
156
+
157
+ def infer(prompt, style, tshirt_color, seed, randomize_seed, width, height, num_inference_steps):
158
+ if randomize_seed:
159
+ seed = random.randint(0, MAX_SEED)
160
+
161
+ try:
162
+ design_image = generate_design(prompt, style, seed, width, height, num_inference_steps, pipe)
163
+ tshirt_preview = create_tshirt_preview(design_image, "tshirt_template.png", tshirt_color)
164
+ return design_image, tshirt_preview, seed
165
+ except Exception as e:
166
+ print(f"Error during inference: {str(e)}")
167
+ return None, None, seed
168
+
169
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
170
+ with gr.Column():
171
+ gr.Markdown("# 👕 T-Shirt Design Generator")
172
+
173
+ with gr.Row():
174
+ prompt = gr.Textbox(label="Design Description", placeholder="Describe your t-shirt design idea")
175
+ style = gr.Dropdown(choices=[""] + STYLES, value="", label="Style")
176
+ tshirt_color = gr.Dropdown(choices=list(TSHIRT_COLORS.keys()), value="White", label="T-Shirt Color")
177
+
178
+ run_button = gr.Button("✨ Generate")
179
+
180
+ with gr.Row():
181
+ result = gr.Image(label="Generated Design")
182
+ preview = gr.Image(label="T-Shirt Preview")
183
+
184
+ with gr.Accordion("Advanced Settings", open=False):
185
+ seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
186
+ randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
187
+ width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=512)
188
+ height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=512)
189
+ num_inference_steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=25)
190
+
191
+ gr.Examples(
192
+ examples=EXAMPLES,
193
+ inputs=[prompt, style, tshirt_color],
194
+ outputs=[result, preview, seed],
195
+ fn=lambda p, s, c: infer(p, s, c, 0, True, 512, 512, 25),
196
+ cache_examples=True
 
 
 
 
 
 
 
 
 
197
  )
198
+
199
+ run_button.click(
200
+ fn=infer,
201
+ inputs=[prompt, style, tshirt_color, seed, randomize_seed, width, height, num_inference_steps],
202
+ outputs=[result, preview, seed]
203
  )
204
+
205
+ return demo
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
 
207
+ if __name__ == "__main__":
208
+ demo = create_interface()
209
+ demo.launch()