Gopalag commited on
Commit
9693fed
·
verified ·
1 Parent(s): 59a7070

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -30
app.py CHANGED
@@ -4,6 +4,8 @@ import random
4
  import spaces
5
  import torch
6
  from diffusers import DiffusionPipeline
 
 
7
 
8
  dtype = torch.bfloat16
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -16,6 +18,41 @@ pipe = DiffusionPipeline.from_pretrained(
16
  MAX_SEED = np.iinfo(np.int32).max
17
  MAX_IMAGE_SIZE = 2048
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  def enhance_prompt_for_tshirt(prompt, style=None):
20
  """Add specific terms to ensure good t-shirt designs."""
21
  style_terms = {
@@ -27,9 +64,9 @@ def enhance_prompt_for_tshirt(prompt, style=None):
27
  }
28
 
29
  base_terms = [
30
- "t-shirt design",
31
- "centered composition",
32
- "high quality",
33
  "professional design",
34
  "clear background"
35
  ]
@@ -43,14 +80,17 @@ def enhance_prompt_for_tshirt(prompt, style=None):
43
  return enhanced_prompt
44
 
45
  @spaces.GPU()
46
- def infer(prompt, style=None, seed=42, randomize_seed=False, width=1024, height=1024,
47
- num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
 
48
  if randomize_seed:
49
  seed = random.randint(0, MAX_SEED)
50
 
51
  enhanced_prompt = enhance_prompt_for_tshirt(prompt, style)
52
  generator = torch.Generator().manual_seed(seed)
53
- image = pipe(
 
 
54
  prompt=enhanced_prompt,
55
  width=width,
56
  height=height,
@@ -59,14 +99,24 @@ def infer(prompt, style=None, seed=42, randomize_seed=False, width=1024, height=
59
  guidance_scale=0.0
60
  ).images[0]
61
 
62
- return image, seed
 
 
 
 
 
 
 
 
 
 
 
63
 
64
  examples = [
65
- ["Cool geometric mountain landscape", "minimal"],
66
- ["Vintage motorcycle with flames", "vintage"],
67
- ["Abstract watercolor butterfly", "artistic"],
68
- ["Sacred geometry mandala", "geometric"],
69
- ["Adventure Awaits typography", "typography"],
70
  ]
71
 
72
  styles = [
@@ -105,13 +155,19 @@ css = """
105
  font-size: 1rem;
106
  transition: all 0.3s ease;
107
  }
 
 
 
 
 
 
108
  """
109
 
110
  with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
111
  with gr.Column(elem_id="col-container"):
112
  gr.Markdown(
113
  """
114
- # 👕 T-Shirt Design Generator
115
  """,
116
  elem_classes=["main-title"]
117
  )
@@ -142,25 +198,30 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
142
  label="Style",
143
  container=False
144
  )
 
 
 
 
 
 
 
145
  run_button = gr.Button(
146
  "✨ Generate",
147
  scale=0,
148
  elem_classes=["generate-button"]
149
  )
150
 
151
- with gr.Row():
152
- with gr.Column():
153
- result = gr.Image(
154
- label="Generated Design",
155
- show_label=True,
156
- elem_classes=["result-image"]
157
- )
158
- with gr.Column():
159
- preview = gr.Image(
160
- label="T-Shirt Preview",
161
- show_label=True,
162
- elem_classes=["preview-image"]
163
- )
164
 
165
  with gr.Accordion("🔧 Advanced Settings", open=False):
166
  with gr.Group():
@@ -203,16 +264,16 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
203
  gr.Examples(
204
  examples=examples,
205
  fn=infer,
206
- inputs=[prompt, style],
207
- outputs=[result, seed],
208
  cache_examples=True
209
  )
210
 
211
  gr.on(
212
  triggers=[run_button.click, prompt.submit],
213
  fn=infer,
214
- inputs=[prompt, style, seed, randomize_seed, width, height, num_inference_steps],
215
- outputs=[result, seed]
216
  )
217
 
218
  demo.launch()
 
4
  import spaces
5
  import torch
6
  from diffusers import DiffusionPipeline
7
+ from PIL import Image
8
+ import io
9
 
10
  dtype = torch.bfloat16
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
18
  MAX_SEED = np.iinfo(np.int32).max
19
  MAX_IMAGE_SIZE = 2048
20
 
21
+ def create_tshirt_preview(design_image, tshirt_color="white"):
22
+ """
23
+ Overlay the design onto a t-shirt template
24
+ """
25
+ # Create a base t-shirt shape
26
+ tshirt_width = 800
27
+ tshirt_height = 1000
28
+
29
+ # Create base t-shirt image
30
+ tshirt = Image.new('RGB', (tshirt_width, tshirt_height), tshirt_color)
31
+
32
+ # Convert design to PIL Image if it's not already
33
+ if not isinstance(design_image, Image.Image):
34
+ design_image = Image.fromarray(design_image)
35
+
36
+ # Resize design to fit nicely on shirt (30% of shirt width)
37
+ design_width = int(tshirt_width * 0.3)
38
+ design_height = int(design_width * design_image.size[1] / design_image.size[0])
39
+ design_image = design_image.resize((design_width, design_height), Image.Resampling.LANCZOS)
40
+
41
+ # Calculate position to center design on shirt (top third of shirt)
42
+ x = (tshirt_width - design_width) // 2
43
+ y = int(tshirt_height * 0.25) # Position in top third
44
+
45
+ # If design has transparency (RGBA), create mask
46
+ if design_image.mode == 'RGBA':
47
+ mask = design_image.split()[3]
48
+ else:
49
+ mask = None
50
+
51
+ # Paste design onto shirt
52
+ tshirt.paste(design_image, (x, y), mask)
53
+
54
+ return tshirt
55
+
56
  def enhance_prompt_for_tshirt(prompt, style=None):
57
  """Add specific terms to ensure good t-shirt designs."""
58
  style_terms = {
 
64
  }
65
 
66
  base_terms = [
67
+ "create a t-shirt design",
68
+ "with centered composition",
69
+ "4k high quality",
70
  "professional design",
71
  "clear background"
72
  ]
 
80
  return enhanced_prompt
81
 
82
  @spaces.GPU()
83
+ def infer(prompt, style=None, tshirt_color="white", seed=42, randomize_seed=False,
84
+ width=1024, height=1024, num_inference_steps=4,
85
+ progress=gr.Progress(track_tqdm=True)):
86
  if randomize_seed:
87
  seed = random.randint(0, MAX_SEED)
88
 
89
  enhanced_prompt = enhance_prompt_for_tshirt(prompt, style)
90
  generator = torch.Generator().manual_seed(seed)
91
+
92
+ # Generate the design
93
+ design_image = pipe(
94
  prompt=enhanced_prompt,
95
  width=width,
96
  height=height,
 
99
  guidance_scale=0.0
100
  ).images[0]
101
 
102
+ # Create t-shirt preview
103
+ tshirt_preview = create_tshirt_preview(design_image, tshirt_color)
104
+
105
+ return design_image, tshirt_preview, seed
106
+
107
+ # Available t-shirt colors
108
+ TSHIRT_COLORS = {
109
+ "White": "#FFFFFF",
110
+ "Black": "#000000",
111
+ "Navy": "#000080",
112
+ "Gray": "#808080"
113
+ }
114
 
115
  examples = [
116
+ ["Cool geometric mountain landscape", "minimal", "White"],
117
+ ["Vintage motorcycle with flames", "vintage", "Black"],
118
+ ["Abstract watercolor butterfly in forest", "artistic", "White"],
119
+ ["Adventure Awaits typography", "typography", "Gray"]
 
120
  ]
121
 
122
  styles = [
 
155
  font-size: 1rem;
156
  transition: all 0.3s ease;
157
  }
158
+ .results-row {
159
+ display: grid;
160
+ grid-template-columns: 1fr 1fr;
161
+ gap: 20px;
162
+ margin-top: 20px;
163
+ }
164
  """
165
 
166
  with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
167
  with gr.Column(elem_id="col-container"):
168
  gr.Markdown(
169
  """
170
+ # 👕 Deradh's T-Shirt Design Generator
171
  """,
172
  elem_classes=["main-title"]
173
  )
 
198
  label="Style",
199
  container=False
200
  )
201
+ with gr.Column(scale=1):
202
+ tshirt_color = gr.Dropdown(
203
+ choices=list(TSHIRT_COLORS.keys()),
204
+ value="White",
205
+ label="T-Shirt Color",
206
+ container=False
207
+ )
208
  run_button = gr.Button(
209
  "✨ Generate",
210
  scale=0,
211
  elem_classes=["generate-button"]
212
  )
213
 
214
+ with gr.Row(elem_classes=["results-row"]):
215
+ result = gr.Image(
216
+ label="Generated Design",
217
+ show_label=True,
218
+ elem_classes=["result-image"]
219
+ )
220
+ preview = gr.Image(
221
+ label="T-Shirt Preview",
222
+ show_label=True,
223
+ elem_classes=["preview-image"]
224
+ )
 
 
225
 
226
  with gr.Accordion("🔧 Advanced Settings", open=False):
227
  with gr.Group():
 
264
  gr.Examples(
265
  examples=examples,
266
  fn=infer,
267
+ inputs=[prompt, style, tshirt_color],
268
+ outputs=[result, preview, seed],
269
  cache_examples=True
270
  )
271
 
272
  gr.on(
273
  triggers=[run_button.click, prompt.submit],
274
  fn=infer,
275
+ inputs=[prompt, style, tshirt_color, seed, randomize_seed, width, height, num_inference_steps],
276
+ outputs=[result, preview, seed]
277
  )
278
 
279
  demo.launch()