eiji commited on
Commit
11992a9
·
1 Parent(s): 230b53c

fix version error

Browse files
Files changed (2) hide show
  1. app.py +82 -438
  2. requirements.txt +2 -2
app.py CHANGED
@@ -2,109 +2,9 @@ import gradio as gr
2
  import torch
3
  import numpy as np
4
  from PIL import Image, ImageDraw
5
- import json
6
  from tkg_dm import TKGDMPipeline
7
 
8
 
9
- def create_canvas_image(width=512, height=512):
10
- """Create a blank canvas for drawing bounding boxes"""
11
- img = Image.new('RGB', (width, height), (240, 240, 240)) # Light gray background
12
- draw = ImageDraw.Draw(img)
13
-
14
- # Add grid lines for better visualization
15
- grid_size = 64
16
- for x in range(0, width, grid_size):
17
- draw.line([(x, 0), (x, height)], fill=(200, 200, 200), width=1)
18
- for y in range(0, height, grid_size):
19
- draw.line([(0, y), (width, y)], fill=(200, 200, 200), width=1)
20
-
21
- # Add instructions
22
- draw.text((10, 10), "Draw bounding boxes to define reserved regions", fill=(100, 100, 100))
23
- draw.text((10, 25), "Click and drag to create boxes", fill=(100, 100, 100))
24
- draw.text((10, 40), "Use 'Clear Boxes' to reset", fill=(100, 100, 100))
25
-
26
- return img
27
-
28
- def draw_boxes_on_canvas(boxes, width=512, height=512):
29
- """Draw bounding boxes on canvas"""
30
- img = create_canvas_image(width, height)
31
- draw = ImageDraw.Draw(img)
32
-
33
- for i, (x1, y1, x2, y2) in enumerate(boxes):
34
- # Convert normalized coordinates to pixel coordinates
35
- px1, py1 = int(x1 * width), int(y1 * height)
36
- px2, py2 = int(x2 * width), int(y2 * height)
37
-
38
- # Draw bounding box
39
- draw.rectangle([px1, py1, px2, py2], outline='red', width=3)
40
- draw.rectangle([px1+1, py1+1, px2-1, py2-1], outline='yellow', width=2)
41
-
42
- # Add semi-transparent fill
43
- overlay = Image.new('RGBA', (width, height), (0, 0, 0, 0))
44
- overlay_draw = ImageDraw.Draw(overlay)
45
- overlay_draw.rectangle([px1, py1, px2, py2], fill=(255, 0, 0, 50))
46
- img = Image.alpha_composite(img.convert('RGBA'), overlay).convert('RGB')
47
- draw = ImageDraw.Draw(img)
48
-
49
- # Add box label
50
- label = f"Box {i+1}"
51
- draw.text((px1+5, py1+5), label, fill='white')
52
- draw.text((px1+4, py1+4), label, fill='black') # Shadow effect
53
-
54
- return img
55
-
56
- def add_bounding_box(bbox_str, x1, y1, x2, y2):
57
- """Add a new bounding box to the string"""
58
- # Ensure coordinates are in correct order and valid range
59
- x1, x2 = max(0, min(x1, x2)), min(1, max(x1, x2))
60
- y1, y2 = max(0, min(y1, y2)), min(1, max(y1, y2))
61
-
62
- # Check minimum size
63
- if x2 - x1 < 0.02 or y2 - y1 < 0.02:
64
- return bbox_str, sync_text_to_canvas(bbox_str)
65
-
66
- new_box = f"{x1:.3f},{y1:.3f},{x2:.3f},{y2:.3f}"
67
-
68
- if bbox_str.strip():
69
- updated_str = bbox_str + ";" + new_box
70
- else:
71
- updated_str = new_box
72
-
73
- return updated_str, sync_text_to_canvas(updated_str)
74
-
75
- def remove_last_box(bbox_str):
76
- """Remove the last bounding box"""
77
- if not bbox_str.strip():
78
- return "", create_canvas_image()
79
-
80
- boxes = bbox_str.split(';')
81
- if boxes:
82
- boxes.pop()
83
-
84
- updated_str = ';'.join(boxes)
85
- return updated_str, sync_text_to_canvas(updated_str)
86
-
87
- def create_box_builder_interface():
88
- """Create a user-friendly box building interface"""
89
- return """
90
- <div style="background: #f8f9fa; padding: 20px; border-radius: 8px; border: 1px solid #dee2e6;">
91
- <h4 style="margin-top: 0; color: #495057;">📦 Bounding Box Builder</h4>
92
- <p style="color: #6c757d; margin-bottom: 15px;">
93
- Define reserved regions where content generation will be suppressed. Use coordinate inputs for precision.
94
- </p>
95
- <div style="background: white; padding: 15px; border-radius: 6px; border: 1px solid #ced4da; margin-bottom: 15px;">
96
- <strong>Instructions:</strong><br>
97
- • Each box is defined by (x1, y1, x2, y2) where coordinates range from 0.0 to 1.0<br>
98
- • (0,0) is top-left corner, (1,1) is bottom-right corner<br>
99
- • Multiple boxes are separated by semicolons<br>
100
- • Red/yellow boxes in preview show reserved regions
101
- </div>
102
- <div style="background: #e7f3ff; padding: 10px; border-radius: 6px; border: 1px solid #b3d9ff;">
103
- <strong>💡 Tips:</strong> Start with default values (0.2,0.2,0.8,0.4) for a center box, then adjust coordinates as needed.
104
- </div>
105
- </div>
106
- """
107
-
108
  def load_preset_boxes(preset_name):
109
  """Load preset bounding box configurations"""
110
  presets = {
@@ -117,25 +17,9 @@ def load_preset_boxes(preset_name):
117
  }
118
  return presets.get(preset_name, "")
119
 
120
- def extract_boxes_from_annotated_image(annotated_data):
121
- """Extract bounding boxes from annotated image data - placeholder for future enhancement"""
122
- # This would be used with more advanced annotation tools
123
- return []
124
-
125
- def update_canvas_with_boxes(annotated_data):
126
- """Update canvas when boxes are drawn - placeholder for future enhancement"""
127
- # For now, return the current canvas
128
- return create_canvas_image(), ""
129
-
130
- def clear_bounding_boxes():
131
- """Clear all bounding boxes"""
132
- return create_canvas_image(), ""
133
 
134
  def parse_bounding_boxes(bbox_str):
135
- """
136
- Parse bounding boxes from string format
137
- Expected format: "x1,y1,x2,y2;x1,y1,x2,y2" or empty for legacy mode
138
- """
139
  if not bbox_str or not bbox_str.strip():
140
  return None
141
 
@@ -146,53 +30,46 @@ def parse_bounding_boxes(bbox_str):
146
  coords = [float(x.strip()) for x in box_str.split(',')]
147
  if len(coords) == 4:
148
  x1, y1, x2, y2 = coords
149
- # Ensure coordinates are in [0,1] range and valid
150
  x1, x2 = max(0, min(x1, x2)), min(1, max(x1, x2))
151
  y1, y2 = max(0, min(y1, y2)), min(1, max(y1, y2))
152
  boxes.append((x1, y1, x2, y2))
153
-
154
  return boxes if boxes else None
155
- except Exception as e:
156
- print(f"Error parsing bounding boxes: {e}")
157
  return None
158
 
159
- def sync_text_to_canvas(bbox_str):
160
- """Sync text input to canvas visualization"""
161
- boxes = parse_bounding_boxes(bbox_str)
162
- if boxes:
163
- return draw_boxes_on_canvas(boxes)
164
- else:
165
- return create_canvas_image()
166
-
167
 
168
- def generate_tkg_dm_image(prompt, ch0_shift, ch1_shift, ch2_shift, ch3_shift, intensity, steps, shift_percent, blur_sigma, model_type, custom_model_id, bounding_boxes_str):
169
  """Generate image using TKG-DM or fallback demo"""
170
 
171
  try:
172
- # Try to use actual TKG-DM pipeline with CPU fallback
173
  device = "cuda" if torch.cuda.is_available() else "cpu"
174
 
175
- # Parse bounding boxes from string input
176
- bounding_boxes = parse_bounding_boxes(bounding_boxes_str)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
- # Initialize pipeline with selected model type and optional custom model ID
179
  model_id = custom_model_id.strip() if custom_model_id.strip() else None
180
  pipeline = TKGDMPipeline(model_id=model_id, model_type=model_type, device=device)
181
 
182
  if pipeline.pipe is not None:
183
- # Use actual pipeline with direct latent channel control
184
  channel_shifts = [ch0_shift, ch1_shift, ch2_shift, ch3_shift]
185
-
186
- # Generate with TKG-DM using direct channel shifts and user controls
187
- # Apply intensity multiplier to base shift percent
188
  final_shift_percent = shift_percent * intensity
189
-
190
- # Use blur sigma (0 means auto-calculate)
191
  blur_sigma_param = None if blur_sigma == 0 else blur_sigma
192
 
193
- # Generate with space-aware TKG-DM using bounding boxes
194
  if not bounding_boxes:
195
- # Default to center box if no boxes specified
196
  bounding_boxes = [(0.3, 0.3, 0.7, 0.7)]
197
 
198
  image = pipeline(
@@ -210,40 +87,28 @@ def generate_tkg_dm_image(prompt, ch0_shift, ch1_shift, ch2_shift, ch3_shift, in
210
 
211
  except Exception as e:
212
  print(f"Using demo mode due to: {e}")
213
- # Fallback to demo visualization
214
  return create_demo_visualization(prompt, ch0_shift, ch1_shift, ch2_shift, ch3_shift, bounding_boxes)
215
 
216
 
217
  def create_demo_visualization(prompt, ch0_shift, ch1_shift, ch2_shift, ch3_shift, bounding_boxes=None):
218
- """Create demo visualization of TKG-DM concept"""
219
-
220
- # Create image with background based on channel shifts
221
- # Convert latent channel shifts to approximate RGB for visualization
222
  approx_color = (
223
- max(0, min(255, 128 + int(ch0_shift * 127))), # Luminance -> Red
224
- max(0, min(255, 128 + int(ch1_shift * 127))), # Color1 -> Green
225
- max(0, min(255, 128 + int(ch2_shift * 127))) # Color2 -> Blue
226
  )
227
  img = Image.new('RGB', (512, 512), approx_color)
228
  draw = ImageDraw.Draw(img)
229
 
230
- # Draw space-aware bounding boxes
231
  if not bounding_boxes:
232
- # Default to center box if none specified
233
  bounding_boxes = [(0.3, 0.3, 0.7, 0.7)]
234
 
235
  for i, (x1, y1, x2, y2) in enumerate(bounding_boxes):
236
  px1, py1 = int(x1 * 512), int(y1 * 512)
237
  px2, py2 = int(x2 * 512), int(y2 * 512)
238
-
239
- # Draw bounding box with gradient effect
240
  draw.rectangle([px1, py1, px2, py2], outline='yellow', width=3)
241
- draw.rectangle([px1+2, py1+2, px2-2, py2-2], outline='orange', width=2)
242
-
243
- # Add box label
244
  draw.text((px1+5, py1+5), f"Box {i+1}", fill='white')
245
 
246
- # Add text
247
  draw.text((10, 10), f"TKG-DM Demo", fill='white')
248
  draw.text((10, 30), f"Prompt: {prompt[:40]}...", fill='white')
249
  draw.text((10, 480), f"Channels: [{ch0_shift:+.2f},{ch1_shift:+.2f},{ch2_shift:+.2f},{ch3_shift:+.2f}]", fill='white')
@@ -251,322 +116,101 @@ def create_demo_visualization(prompt, ch0_shift, ch1_shift, ch2_shift, ch3_shift
251
  return img
252
 
253
 
254
- # Create intuitive interface with step-by-step workflow
255
- with gr.Blocks(title="🎨 SAWNA: Space-Aware Text-to-Image Generation", theme=gr.themes.Default()) as demo:
 
 
256
 
257
- # Header section with workflow explanation
258
  with gr.Row():
259
- with gr.Column(scale=3):
260
- gr.Markdown("""
261
- # 🎨 SAWNA: Space-Aware Text-to-Image Generation
262
 
263
- Create professional images with **guaranteed empty spaces** for headlines, logos, and product shots.
264
- Perfect for advertisements, posters, and UI mockups.
265
- """)
266
- with gr.Column(scale=2):
267
- gr.Markdown("""
268
- ### 🚀 Quick Start:
269
- 1. **Describe** your image in the text prompt
270
- 2. **Choose** where to keep empty (preset or custom)
271
- 3. **Adjust** colors and style (optional)
272
- 4. **Generate** with guaranteed reserved regions
273
- """)
274
-
275
- with gr.Row():
276
- gr.Markdown("""
277
- ---
278
- 💡 **How it works**: SAWNA uses advanced noise manipulation to suppress content generation in your specified regions,
279
- ensuring they remain empty for your design elements while maintaining high quality in other areas.
280
- """)
281
-
282
- gr.Markdown("## 🎯 Create Your Space-Aware Image")
283
-
284
- # Main workflow section
285
- with gr.Row():
286
- # Left column - Input and controls
287
- with gr.Column(scale=2):
288
 
289
- # Step 1: Text Prompt
290
- with gr.Group():
291
- gr.Markdown("## 📝 Step 1: Describe Your Image")
292
- prompt = gr.Textbox(
293
- value="A majestic lion in a natural landscape",
294
- label="Text Prompt",
295
- placeholder="Describe what you want to generate...",
296
- lines=2
297
- )
298
-
299
- # Step 2: Reserved Regions
 
 
 
 
 
 
 
300
  with gr.Group():
301
- gr.Markdown("## 🔲 Step 2: Define Empty Regions")
302
- gr.Markdown("*Choose areas that must stay empty for your design elements*")
303
-
304
- # Quick presets
305
- with gr.Row():
306
- preset_dropdown = gr.Dropdown(
307
- choices=[
308
- ("None (Default Center)", "center_box"),
309
- ("Top Banner", "top_strip"),
310
- ("Bottom Banner", "bottom_strip"),
311
- ("Side Panels", "left_right"),
312
- ("Corner Logos", "corners"),
313
- ("Full Frame", "frame")
314
- ],
315
- label="🚀 Quick Presets",
316
- value="center_box"
317
- )
318
-
319
- # Manual box creation
320
- gr.Markdown("**Or Create Custom Boxes:**")
321
- with gr.Row():
322
- with gr.Column(scale=1):
323
- x1_input = gr.Number(value=0.3, minimum=0.0, maximum=1.0, step=0.01, label="Left (X1)")
324
- x2_input = gr.Number(value=0.7, minimum=0.0, maximum=1.0, step=0.01, label="Right (X2)")
325
- with gr.Column(scale=1):
326
- y1_input = gr.Number(value=0.3, minimum=0.0, maximum=1.0, step=0.01, label="Top (Y1)")
327
- y2_input = gr.Number(value=0.7, minimum=0.0, maximum=1.0, step=0.01, label="Bottom (Y2)")
328
-
329
- with gr.Row():
330
- add_box_btn = gr.Button("➕ Add Region", variant="primary", size="sm")
331
- remove_box_btn = gr.Button("❌ Remove Last", variant="secondary", size="sm")
332
- clear_btn = gr.Button("🗑️ Clear All", variant="secondary", size="sm")
333
-
334
- # Text representation
335
- bounding_boxes_str = gr.Textbox(
336
- value="0.3,0.3,0.7,0.7",
337
- label="📋 Region Coordinates",
338
- placeholder="x1,y1,x2,y2;x1,y1,x2,y2 (auto-updated)",
339
- lines=2,
340
- info="Coordinates are normalized (0.0 = left/top, 1.0 = right/bottom)"
341
  )
342
 
343
- # Step 3: Color and Style Controls
344
- with gr.Group():
345
- gr.Markdown("## 🎨 Step 3: Fine-tune Colors")
346
- gr.Markdown("*Adjust the 4 latent channels to control image colors and style*")
347
-
348
  with gr.Row():
349
- ch0_shift = gr.Slider(-1.0, 1.0, 0.0, label="💡 Brightness", info="Overall image brightness")
350
- ch1_shift = gr.Slider(-1.0, 1.0, 1.0, label="🔵 Blue-Red Balance", info="Shift toward blue (+) or red (-)")
351
-
352
- with gr.Row():
353
- ch2_shift = gr.Slider(-1.0, 1.0, 1.0, label="🟡 Yellow-Blue Balance", info="Shift toward yellow (+) or dark blue (-)")
354
- ch3_shift = gr.Slider(-1.0, 1.0, 0.0, label="⚪ Contrast", info="Adjust overall contrast")
355
-
356
- # Right column - Preview and results
357
- with gr.Column(scale=1):
358
-
359
- # Preview section
360
- with gr.Group():
361
- gr.Markdown("## 👁️ Preview: Empty Regions")
362
- bbox_preview = gr.Image(
363
- value=create_canvas_image(),
364
- label="Reserved Regions Visualization",
365
- interactive=False,
366
- type="pil"
367
- )
368
- gr.Markdown("*Yellow boxes show where content will be suppressed*")
369
 
370
- # Advanced controls (collapsible)
371
- with gr.Accordion("🎛️ Advanced Generation Settings", open=True):
372
- with gr.Row():
373
- with gr.Column():
374
- gr.Markdown("### Generation Settings")
375
- intensity = gr.Slider(0.5, 3.0, 1.0, label="Effect Intensity", info="How strongly to suppress content in empty regions")
376
- steps = gr.Slider(10, 100, 25, label="Quality Steps", info="More steps = higher quality, slower generation")
377
-
378
- gr.Markdown("### TKG-DM Technical Controls")
379
- shift_percent = gr.Slider(0.01, 0.15, 0.07, step=0.005, label="🎯 Shift Percent", info="Base shift percentage for noise optimization (±7% default)")
380
- blur_sigma = gr.Slider(0.0, 5.0, 0.0, step=0.1, label="🌫️ Blur Sigma", info="Gaussian blur for soft transitions (0 = auto)")
381
-
382
- with gr.Column():
383
- gr.Markdown("### Model Selection")
384
- model_type = gr.Dropdown(
385
- ["sd1.5", "sdxl", "sd2.1"],
386
- value="sd1.5",
387
- label="Model Architecture",
388
- info="SDXL for highest quality, SD1.5 for speed"
389
- )
390
- custom_model_id = gr.Textbox(
391
- "",
392
- label="Custom Model (Optional)",
393
- placeholder="e.g., dreamlike-art/dreamlike-diffusion-1.0",
394
- info="Use any Hugging Face Stable Diffusion model"
395
- )
396
 
397
- # Generation section
398
- with gr.Row():
399
- with gr.Column(scale=1):
400
- generate_btn = gr.Button(
401
- "🎨 Generate Space-Aware Image",
402
- variant="primary",
403
- size="lg",
404
- elem_id="generate-btn"
405
- )
406
- gr.Markdown("*Click to create your image with guaranteed empty regions*")
407
-
408
- with gr.Column(scale=3):
409
- output_image = gr.Image(
410
- label="✨ Generated Image",
411
- type="pil",
412
- height=500,
413
- elem_id="output-image"
414
- )
415
 
416
  # Examples section
417
  with gr.Accordion("📚 Example Prompts & Layouts", open=False):
418
- gr.Markdown("""
419
- ### Try these professional design scenarios:
420
- Click any example to load it automatically and see how SAWNA handles different layout requirements.
421
- """)
422
 
423
- gr.Examples(
424
  examples=[
425
  [
426
  "A majestic lion in African savanna",
427
  0.2, 0.3, 0.0, 0.0, 1.0, 25, 0.07, 0.0, "sd1.5", "",
428
- "0.3,0.3,0.7,0.7"
429
  ],
430
  [
431
  "Modern cityscape with skyscrapers at sunset",
432
  -0.1, -0.3, 0.2, 0.1, 1.2, 30, 0.08, 0.0, "sdxl", "",
433
- "0.0,0.0,1.0,0.3"
434
  ],
435
  [
436
  "Vintage luxury car on mountain road",
437
  0.1, 0.2, -0.1, -0.2, 0.9, 25, 0.06, 0.0, "sd1.5", "",
438
- "0.0,0.7,1.0,1.0"
439
  ],
440
  [
441
  "Space astronaut floating in nebula",
442
  0.0, 0.4, -0.2, 0.3, 1.1, 35, 0.09, 1.8, "sd2.1", "",
443
- "0.0,0.2,0.3,0.8;0.7,0.2,1.0,0.8"
444
  ],
445
  [
446
- "Product photography: premium watch (fine-tuned)",
447
  0.2, 0.0, 0.1, -0.1, 1.3, 40, 0.12, 2.5, "sdxl", "",
448
- "0.0,0.0,1.0,0.2;0.0,0.8,1.0,1.0;0.0,0.2,0.2,0.8;0.8,0.2,1.0,0.8"
449
  ]
450
  ],
451
  inputs=[prompt, ch0_shift, ch1_shift, ch2_shift, ch3_shift,
452
- intensity, steps, shift_percent, blur_sigma, model_type, custom_model_id, bounding_boxes_str],
453
- label="Professional Use Cases"
454
  )
455
-
456
- # Add custom CSS for better styling
457
- demo.load(fn=None, js="""
458
- function() {
459
- // Add custom styling
460
- const style = document.createElement('style');
461
- style.textContent = `
462
- .gradio-container {
463
- max-width: 1400px !important;
464
- margin: auto;
465
- }
466
-
467
- #generate-btn {
468
- background: linear-gradient(45deg, #7c3aed, #a855f7) !important;
469
- border: none !important;
470
- font-weight: bold !important;
471
- padding: 15px 30px !important;
472
- font-size: 16px !important;
473
- }
474
-
475
- #output-image {
476
- border-radius: 12px !important;
477
- box-shadow: 0 8px 32px rgba(0,0,0,0.1) !important;
478
- }
479
-
480
- .gr-group {
481
- border-radius: 12px !important;
482
- border: 1px solid #e5e7eb !important;
483
- padding: 20px !important;
484
- margin-bottom: 20px !important;
485
- }
486
-
487
- .gr-accordion {
488
- border-radius: 8px !important;
489
- border: 1px solid #d1d5db !important;
490
- }
491
- `;
492
- document.head.appendChild(style);
493
- return [];
494
- }
495
- """)
496
-
497
- # Event handlers
498
- def generate_wrapper(*args):
499
- return generate_tkg_dm_image(*args)
500
-
501
- def clear_boxes_handler():
502
- """Clear boxes and update preview"""
503
- return "", create_canvas_image()
504
-
505
- def update_preview_from_text(bbox_str):
506
- """Update preview image from text input"""
507
- return sync_text_to_canvas(bbox_str)
508
-
509
- def add_box_handler(bbox_str, x1, y1, x2, y2):
510
- """Add a new box and update preview"""
511
- updated_str, preview_img = add_bounding_box(bbox_str, x1, y1, x2, y2)
512
- return updated_str, preview_img
513
-
514
- def remove_box_handler(bbox_str):
515
- """Remove last box and update preview"""
516
- return remove_last_box(bbox_str)
517
-
518
- def load_preset_handler(preset_name):
519
- """Load preset boxes and update preview"""
520
- if preset_name and preset_name != "center_box": # Don't reload default
521
- preset_str = load_preset_boxes(preset_name)
522
- return preset_str, sync_text_to_canvas(preset_str)
523
- elif preset_name == "center_box":
524
- preset_str = "0.3,0.3,0.7,0.7"
525
- return preset_str, sync_text_to_canvas(preset_str)
526
- return "", create_canvas_image()
527
-
528
- # Preset dropdown
529
- preset_dropdown.change(
530
- fn=load_preset_handler,
531
- inputs=[preset_dropdown],
532
- outputs=[bounding_boxes_str, bbox_preview]
533
- )
534
-
535
- # Add box button
536
- add_box_btn.click(
537
- fn=add_box_handler,
538
- inputs=[bounding_boxes_str, x1_input, y1_input, x2_input, y2_input],
539
- outputs=[bounding_boxes_str, bbox_preview]
540
- )
541
-
542
- # Remove last box button
543
- remove_box_btn.click(
544
- fn=remove_box_handler,
545
- inputs=[bounding_boxes_str],
546
- outputs=[bounding_boxes_str, bbox_preview]
547
- )
548
-
549
- # Clear all boxes button
550
- clear_btn.click(
551
- fn=clear_boxes_handler,
552
- outputs=[bounding_boxes_str, bbox_preview]
553
- )
554
-
555
- # Sync text to preview canvas
556
- bounding_boxes_str.change(
557
- fn=update_preview_from_text,
558
- inputs=[bounding_boxes_str],
559
- outputs=[bbox_preview]
560
- )
561
-
562
- # Generate button
563
- generate_btn.click(
564
- fn=generate_wrapper,
565
- inputs=[prompt, ch0_shift, ch1_shift, ch2_shift, ch3_shift,
566
- intensity, steps, shift_percent, blur_sigma, model_type, custom_model_id, bounding_boxes_str],
567
- outputs=[output_image]
568
- )
569
-
570
 
571
  if __name__ == "__main__":
572
- demo.launch(share=True)
 
2
  import torch
3
  import numpy as np
4
  from PIL import Image, ImageDraw
 
5
  from tkg_dm import TKGDMPipeline
6
 
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  def load_preset_boxes(preset_name):
9
  """Load preset bounding box configurations"""
10
  presets = {
 
17
  }
18
  return presets.get(preset_name, "")
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  def parse_bounding_boxes(bbox_str):
22
+ """Parse bounding boxes from string format"""
 
 
 
23
  if not bbox_str or not bbox_str.strip():
24
  return None
25
 
 
30
  coords = [float(x.strip()) for x in box_str.split(',')]
31
  if len(coords) == 4:
32
  x1, y1, x2, y2 = coords
 
33
  x1, x2 = max(0, min(x1, x2)), min(1, max(x1, x2))
34
  y1, y2 = max(0, min(y1, y2)), min(1, max(y1, y2))
35
  boxes.append((x1, y1, x2, y2))
 
36
  return boxes if boxes else None
37
+ except:
 
38
  return None
39
 
 
 
 
 
 
 
 
 
40
 
41
+ def generate_tkg_dm_image(prompt, ch0_shift, ch1_shift, ch2_shift, ch3_shift, intensity, steps, shift_percent, blur_sigma, model_type, custom_model_id, bounding_boxes_str, preset, x1, y1, x2, y2):
42
  """Generate image using TKG-DM or fallback demo"""
43
 
44
  try:
 
45
  device = "cuda" if torch.cuda.is_available() else "cpu"
46
 
47
+ # Handle preset loading and manual box addition
48
+ final_bbox_str = bounding_boxes_str
49
+ if preset and preset != "center_box":
50
+ preset_str = load_preset_boxes(preset)
51
+ if preset_str:
52
+ final_bbox_str = preset_str
53
+
54
+ # Add manual box if coordinates are provided and different from default
55
+ if not (x1 == 0.3 and y1 == 0.3 and x2 == 0.7 and y2 == 0.7):
56
+ manual_box = f"{x1:.3f},{y1:.3f},{x2:.3f},{y2:.3f}"
57
+ if final_bbox_str.strip():
58
+ final_bbox_str += ";" + manual_box
59
+ else:
60
+ final_bbox_str = manual_box
61
+
62
+ bounding_boxes = parse_bounding_boxes(final_bbox_str)
63
 
 
64
  model_id = custom_model_id.strip() if custom_model_id.strip() else None
65
  pipeline = TKGDMPipeline(model_id=model_id, model_type=model_type, device=device)
66
 
67
  if pipeline.pipe is not None:
 
68
  channel_shifts = [ch0_shift, ch1_shift, ch2_shift, ch3_shift]
 
 
 
69
  final_shift_percent = shift_percent * intensity
 
 
70
  blur_sigma_param = None if blur_sigma == 0 else blur_sigma
71
 
 
72
  if not bounding_boxes:
 
73
  bounding_boxes = [(0.3, 0.3, 0.7, 0.7)]
74
 
75
  image = pipeline(
 
87
 
88
  except Exception as e:
89
  print(f"Using demo mode due to: {e}")
 
90
  return create_demo_visualization(prompt, ch0_shift, ch1_shift, ch2_shift, ch3_shift, bounding_boxes)
91
 
92
 
93
  def create_demo_visualization(prompt, ch0_shift, ch1_shift, ch2_shift, ch3_shift, bounding_boxes=None):
94
+ """Create demo visualization"""
 
 
 
95
  approx_color = (
96
+ max(0, min(255, 128 + int(ch0_shift * 127))),
97
+ max(0, min(255, 128 + int(ch1_shift * 127))),
98
+ max(0, min(255, 128 + int(ch2_shift * 127)))
99
  )
100
  img = Image.new('RGB', (512, 512), approx_color)
101
  draw = ImageDraw.Draw(img)
102
 
 
103
  if not bounding_boxes:
 
104
  bounding_boxes = [(0.3, 0.3, 0.7, 0.7)]
105
 
106
  for i, (x1, y1, x2, y2) in enumerate(bounding_boxes):
107
  px1, py1 = int(x1 * 512), int(y1 * 512)
108
  px2, py2 = int(x2 * 512), int(y2 * 512)
 
 
109
  draw.rectangle([px1, py1, px2, py2], outline='yellow', width=3)
 
 
 
110
  draw.text((px1+5, py1+5), f"Box {i+1}", fill='white')
111
 
 
112
  draw.text((10, 10), f"TKG-DM Demo", fill='white')
113
  draw.text((10, 30), f"Prompt: {prompt[:40]}...", fill='white')
114
  draw.text((10, 480), f"Channels: [{ch0_shift:+.2f},{ch1_shift:+.2f},{ch2_shift:+.2f},{ch3_shift:+.2f}]", fill='white')
 
116
  return img
117
 
118
 
119
+ # Create Gradio 5.x compatible interface with improved syntax
120
+ with gr.Blocks(title="🎨 SAWNA: Space-Aware Text-to-Image Generation") as demo:
121
+ gr.Markdown("# 🎨 SAWNA: Space-Aware Text-to-Image Generation")
122
+ gr.Markdown("Generate images with precise background control using space-aware noise optimization.")
123
 
 
124
  with gr.Row():
125
+ with gr.Column():
126
+ prompt = gr.Textbox(value="A majestic lion", label="Prompt")
 
127
 
128
+ with gr.Row():
129
+ ch0_shift = gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, label="Channel 0 (Luminance/Color)")
130
+ ch1_shift = gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, label="Channel 1 (Pink/Yellow+, Red/Blue-)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
+ with gr.Row():
133
+ ch2_shift = gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, label="Channel 2 (Pink/Yellow+, Red/Blue-)")
134
+ ch3_shift = gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, label="Channel 3 (Luminance/Color)")
135
+
136
+ with gr.Row():
137
+ intensity = gr.Slider(minimum=0.5, maximum=3.0, value=1.0, label="Shift Intensity")
138
+ steps = gr.Slider(minimum=10, maximum=100, value=25, label="Steps")
139
+
140
+ with gr.Row():
141
+ shift_percent = gr.Slider(minimum=0.01, maximum=0.15, value=0.07, label="Shift Percent")
142
+ blur_sigma = gr.Slider(minimum=0.0, maximum=5.0, value=0.0, label="Blur Sigma (0=auto)")
143
+
144
+ model_type = gr.Dropdown(choices=["sd1.5", "sdxl", "sd2.1"], value="sd1.5", label="Model Type")
145
+ custom_model_id = gr.Textbox(value="", label="Custom Model ID (optional)")
146
+ bounding_boxes_str = gr.Textbox(value="0.3,0.3,0.7,0.7", label="Bounding Boxes",
147
+ placeholder="x1,y1,x2,y2;x1,y1,x2,y2")
148
+
149
+ # Box building controls
150
  with gr.Group():
151
+ gr.Markdown("### Box Building Controls")
152
+ preset = gr.Dropdown(
153
+ choices=["center_box", "top_strip", "bottom_strip", "left_right", "corners", "frame"],
154
+ value="center_box",
155
+ label="Quick Presets"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  )
157
 
 
 
 
 
 
158
  with gr.Row():
159
+ x1 = gr.Number(value=0.3, minimum=0, maximum=1, label="Box X1")
160
+ y1 = gr.Number(value=0.3, minimum=0, maximum=1, label="Box Y1")
161
+ x2 = gr.Number(value=0.7, minimum=0, maximum=1, label="Box X2")
162
+ y2 = gr.Number(value=0.7, minimum=0, maximum=1, label="Box Y2")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
+ generate_btn = gr.Button("Generate Image", variant="primary")
165
+
166
+ with gr.Column():
167
+ output_image = gr.Image(label="Generated Image")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
169
+ # Event handler
170
+ generate_btn.click(
171
+ fn=generate_tkg_dm_image,
172
+ inputs=[prompt, ch0_shift, ch1_shift, ch2_shift, ch3_shift,
173
+ intensity, steps, shift_percent, blur_sigma, model_type, custom_model_id, bounding_boxes_str,
174
+ preset, x1, y1, x2, y2],
175
+ outputs=output_image
176
+ )
 
 
 
 
 
 
 
 
 
 
177
 
178
  # Examples section
179
  with gr.Accordion("📚 Example Prompts & Layouts", open=False):
180
+ gr.Markdown("### Try these professional design scenarios:")
 
 
 
181
 
182
+ examples = gr.Examples(
183
  examples=[
184
  [
185
  "A majestic lion in African savanna",
186
  0.2, 0.3, 0.0, 0.0, 1.0, 25, 0.07, 0.0, "sd1.5", "",
187
+ "0.3,0.3,0.7,0.7", "center_box", 0.3, 0.3, 0.7, 0.7
188
  ],
189
  [
190
  "Modern cityscape with skyscrapers at sunset",
191
  -0.1, -0.3, 0.2, 0.1, 1.2, 30, 0.08, 0.0, "sdxl", "",
192
+ "0.0,0.0,1.0,0.3", "top_strip", 0.0, 0.0, 1.0, 0.3
193
  ],
194
  [
195
  "Vintage luxury car on mountain road",
196
  0.1, 0.2, -0.1, -0.2, 0.9, 25, 0.06, 0.0, "sd1.5", "",
197
+ "0.0,0.7,1.0,1.0", "bottom_strip", 0.0, 0.7, 1.0, 1.0
198
  ],
199
  [
200
  "Space astronaut floating in nebula",
201
  0.0, 0.4, -0.2, 0.3, 1.1, 35, 0.09, 1.8, "sd2.1", "",
202
+ "0.0,0.2,0.3,0.8;0.7,0.2,1.0,0.8", "left_right", 0.0, 0.2, 0.3, 0.8
203
  ],
204
  [
205
+ "Product photography: premium watch",
206
  0.2, 0.0, 0.1, -0.1, 1.3, 40, 0.12, 2.5, "sdxl", "",
207
+ "0.0,0.0,1.0,0.2;0.0,0.8,1.0,1.0;0.0,0.2,0.2,0.8;0.8,0.2,1.0,0.8", "frame", 0.0, 0.0, 1.0, 0.2
208
  ]
209
  ],
210
  inputs=[prompt, ch0_shift, ch1_shift, ch2_shift, ch3_shift,
211
+ intensity, steps, shift_percent, blur_sigma, model_type, custom_model_id, bounding_boxes_str,
212
+ preset, x1, y1, x2, y2]
213
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
 
215
  if __name__ == "__main__":
216
+ demo.launch(share=True, server_name="0.0.0.0")
requirements.txt CHANGED
@@ -4,10 +4,10 @@ diffusers>=0.21.0
4
  transformers>=4.25.0
5
  accelerate>=0.20.0
6
  safetensors>=0.3.0
7
- gradio==3.50.2
8
  pillow>=9.0.0
9
  numpy>=1.21.0
10
  scipy>=1.7.0
11
  ftfy>=6.1.0
12
  regex>=2022.0.0
13
- requests>=2.25.0
 
4
  transformers>=4.25.0
5
  accelerate>=0.20.0
6
  safetensors>=0.3.0
7
+ gradio==5.34.1
8
  pillow>=9.0.0
9
  numpy>=1.21.0
10
  scipy>=1.7.0
11
  ftfy>=6.1.0
12
  regex>=2022.0.0
13
+ requests>=2.25.0