snackshell commited on
Commit
5ddcd4f
·
verified ·
1 Parent(s): 4103aa4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -71
app.py CHANGED
@@ -5,10 +5,13 @@ from PIL import Image, ImageDraw, ImageFont
5
  import torch
6
  from diffusers import DiffusionPipeline
7
  import io
 
8
 
9
  # ===== CONFIG =====
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
  torch_dtype = torch.float16 if device == "cuda" else torch.float32
 
 
12
  model_repo_id = "stabilityai/sdxl-turbo"
13
  pipe = DiffusionPipeline.from_pretrained(
14
  model_repo_id,
@@ -17,108 +20,161 @@ pipe = DiffusionPipeline.from_pretrained(
17
  )
18
  pipe.to(device)
19
 
 
 
 
 
20
  MAX_SEED = np.iinfo(np.int32).max
21
- IMAGE_WIDTH = 768
22
- IMAGE_HEIGHT = 768
23
  WATERMARK_TEXT = "SelamGPT"
24
 
25
- # ===== WATERMARK FUNCTION =====
26
  def add_watermark(image):
27
- draw = ImageDraw.Draw(image)
28
- font_size = int(image.width * 0.03)
29
  try:
30
- font = ImageFont.truetype("Roboto-Bold.ttf", font_size)
31
- except:
32
- font = ImageFont.load_default()
33
- text_width = draw.textlength(WATERMARK_TEXT, font=font)
34
- x = image.width - text_width - 12
35
- y = image.height - font_size - 10
36
- draw.text((x+1, y+1), WATERMARK_TEXT, font=font, fill=(0, 0, 0, 128))
37
- draw.text((x, y), WATERMARK_TEXT, font=font, fill=(255, 255, 255))
38
- return image
39
-
40
- # ===== INFERENCE FUNCTION =====
 
 
 
 
 
 
 
 
 
 
 
41
  def generate(
42
  prompt,
43
- negative_prompt,
44
- seed,
45
- randomize_seed,
46
- guidance_scale,
47
- num_inference_steps,
48
  progress=gr.Progress(track_tqdm=True),
49
  ):
50
  if not prompt.strip():
51
  return None, "⚠️ Please enter a prompt"
52
 
53
- if randomize_seed:
 
 
 
54
  seed = random.randint(0, MAX_SEED)
55
-
56
  generator = torch.manual_seed(seed)
 
 
57
  result = pipe(
58
  prompt=prompt,
59
  negative_prompt=negative_prompt,
60
- width=IMAGE_WIDTH,
61
- height=IMAGE_HEIGHT,
62
  guidance_scale=guidance_scale,
63
- num_inference_steps=num_inference_steps,
64
  generator=generator,
65
  ).images[0]
66
-
 
67
  watermarked = add_watermark(result)
68
  buffer = io.BytesIO()
69
- watermarked.convert("RGB").save(buffer, format="JPEG", quality=70)
70
  buffer.seek(0)
71
- return Image.open(buffer), seed
 
 
 
 
72
 
73
  # ===== EXAMPLES =====
74
  examples = [
75
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
76
- "An astronaut riding a green horse",
77
- "A delicious ceviche cheesecake slice",
78
  ]
79
 
80
- # ===== INTERFACE =====
81
- css = "#container { max-width: 700px; margin: auto; }"
82
-
83
- with gr.Blocks(css=css, title="SelamGPT Turbo Generator") as demo:
84
- with gr.Column(elem_id="container"):
85
- gr.Markdown("# 🖼️ SelamGPT Image Generator")
86
 
87
- with gr.Row():
 
 
 
 
 
 
 
88
  prompt = gr.Textbox(
89
- label="Prompt",
90
- show_label=False,
91
- placeholder="Enter your prompt",
92
- lines=1,
93
- scale=3
94
  )
95
- generate_btn = gr.Button("Generate", variant="primary")
96
-
97
- output_image = gr.Image(label="Generated Image", type="pil", format="jpeg")
98
- seed_display = gr.Textbox(label="Seed Used", interactive=False)
99
-
100
- with gr.Accordion("⚙️ Advanced Settings", open=False):
101
- negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="What to avoid (optional)", max_lines=1)
102
- randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
103
- seed = gr.Slider(0, MAX_SEED, step=1, label="Seed", value=0)
104
-
105
- guidance_scale = gr.Slider(0.0, 10.0, step=0.1, label="Guidance Scale", value=0.0)
106
- num_inference_steps = gr.Slider(1, 10, step=1, label="Inference Steps", value=2)
107
-
108
- gr.Examples(examples=examples, inputs=[prompt])
109
-
110
- generate_btn.click(
111
- fn=generate,
112
- inputs=[
113
- prompt,
114
- negative_prompt,
115
- seed,
116
- randomize_seed,
117
- guidance_scale,
118
- num_inference_steps
119
- ],
120
- outputs=[output_image, seed_display]
121
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
  if __name__ == "__main__":
124
- demo.launch()
 
 
5
  import torch
6
  from diffusers import DiffusionPipeline
7
  import io
8
+ import time
9
 
10
  # ===== CONFIG =====
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
  torch_dtype = torch.float16 if device == "cuda" else torch.float32
13
+
14
+ # Using SDXL Turbo for fastest generation
15
  model_repo_id = "stabilityai/sdxl-turbo"
16
  pipe = DiffusionPipeline.from_pretrained(
17
  model_repo_id,
 
20
  )
21
  pipe.to(device)
22
 
23
+ # Enable memory efficient attention and channels last for better performance
24
+ pipe.enable_xformers_memory_efficient_attention()
25
+ pipe.unet.to(memory_format=torch.channels_last)
26
+
27
  MAX_SEED = np.iinfo(np.int32).max
28
+ IMAGE_SIZE = 1024 # Same as original code
 
29
  WATERMARK_TEXT = "SelamGPT"
30
 
31
+ # ===== OPTIMIZED WATERMARK FUNCTION =====
32
  def add_watermark(image):
33
+ """Optimized watermark function matching original style"""
 
34
  try:
35
+ draw = ImageDraw.Draw(image)
36
+ font_size = 24 # Fixed size as in original
37
+
38
+ try:
39
+ font = ImageFont.truetype("Roboto-Bold.ttf", font_size)
40
+ except:
41
+ font = ImageFont.load_default(font_size)
42
+
43
+ text_width = draw.textlength(WATERMARK_TEXT, font=font)
44
+ x = image.width - text_width - 10
45
+ y = image.height - 34
46
+
47
+ # Shadow effect
48
+ draw.text((x+1, y+1), WATERMARK_TEXT, font=font, fill=(0, 0, 0, 128))
49
+ draw.text((x, y), WATERMARK_TEXT, font=font, fill=(255, 255, 255))
50
+
51
+ return image
52
+ except Exception as e:
53
+ print(f"Watermark error: {str(e)}")
54
+ return image
55
+
56
+ # ===== ULTRA-FAST INFERENCE FUNCTION =====
57
  def generate(
58
  prompt,
59
+ negative_prompt="",
60
+ seed=None,
61
+ randomize_seed=True,
62
+ guidance_scale=0.0, # 0.0 for turbo models
63
+ num_inference_steps=1, # Can be as low as 1-2 for turbo
64
  progress=gr.Progress(track_tqdm=True),
65
  ):
66
  if not prompt.strip():
67
  return None, "⚠️ Please enter a prompt"
68
 
69
+ start_time = time.time()
70
+
71
+ # Seed handling
72
+ if randomize_seed or seed is None:
73
  seed = random.randint(0, MAX_SEED)
74
+
75
  generator = torch.manual_seed(seed)
76
+
77
+ # Ultra-fast generation with minimal steps
78
  result = pipe(
79
  prompt=prompt,
80
  negative_prompt=negative_prompt,
81
+ width=IMAGE_SIZE,
82
+ height=IMAGE_SIZE,
83
  guidance_scale=guidance_scale,
84
+ num_inference_steps=max(1, num_inference_steps), # Minimum 1 step
85
  generator=generator,
86
  ).images[0]
87
+
88
+ # Optimized watermark and JPG conversion
89
  watermarked = add_watermark(result)
90
  buffer = io.BytesIO()
91
+ watermarked.save(buffer, format="JPEG", quality=85, optimize=True)
92
  buffer.seek(0)
93
+
94
+ gen_time = time.time() - start_time
95
+ status = f"✔️ Generated in {gen_time:.2f}s | Seed: {seed}"
96
+
97
+ return Image.open(buffer), status
98
 
99
  # ===== EXAMPLES =====
100
  examples = [
101
+ ["An ancient Aksumite warrior in cyberpunk armor, 4k detailed"],
102
+ ["Traditional Ethiopian coffee ceremony in zero gravity"],
103
+ ["Portrait of a Habesha queen with golden jewelry"]
104
  ]
105
 
106
+ # ===== OPTIMIZED INTERFACE =====
107
+ theme = gr.themes.Default(
108
+ primary_hue="emerald",
109
+ secondary_hue="amber",
110
+ font=[gr.themes.GoogleFont("Poppins"), "Arial", "sans-serif"]
111
+ )
112
 
113
+ with gr.Blocks(theme=theme, title="SelamGPT Turbo Generator") as demo:
114
+ gr.Markdown("""
115
+ # 🎨 SelamGPT Turbo Image Generator
116
+ *Ultra-fast 1024x1024 image generation with SDXL-Turbo*
117
+ """)
118
+
119
+ with gr.Row():
120
+ with gr.Column(scale=3):
121
  prompt = gr.Textbox(
122
+ label="Describe your image",
123
+ placeholder="A futuristic Ethiopian city with flying cars...",
124
+ lines=3,
125
+ max_lines=5
 
126
  )
127
+ with gr.Row():
128
+ generate_btn = gr.Button("Generate Image", variant="primary")
129
+ clear_btn = gr.Button("Clear")
130
+
131
+ gr.Examples(
132
+ examples=examples,
133
+ inputs=[prompt]
134
+ )
135
+
136
+ with gr.Column(scale=2):
137
+ output_image = gr.Image(
138
+ label="Generated Image",
139
+ type="pil",
140
+ format="jpeg",
141
+ height=512
142
+ )
143
+ status_output = gr.Textbox(
144
+ label="Status",
145
+ interactive=False
146
+ )
147
+
148
+ with gr.Accordion("⚙️ Advanced Settings", open=False):
149
+ negative_prompt = gr.Textbox(
150
+ label="Negative Prompt",
151
+ placeholder="What to avoid (optional)",
152
+ max_lines=1
153
  )
154
+ with gr.Row():
155
+ randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
156
+ seed = gr.Number(label="Seed", value=0, precision=0)
157
+ guidance_scale = gr.Slider(0.0, 1.0, value=0.0, step=0.1, label="Guidance Scale")
158
+ num_inference_steps = gr.Slider(1, 4, value=1, step=1, label="Inference Steps")
159
+
160
+ generate_btn.click(
161
+ fn=generate,
162
+ inputs=[
163
+ prompt,
164
+ negative_prompt,
165
+ seed,
166
+ randomize_seed,
167
+ guidance_scale,
168
+ num_inference_steps
169
+ ],
170
+ outputs=[output_image, status_output]
171
+ )
172
+
173
+ clear_btn.click(
174
+ fn=lambda: [None, ""],
175
+ outputs=[output_image, status_output]
176
+ )
177
 
178
  if __name__ == "__main__":
179
+ demo.queue(max_size=4) # Increased queue for better throughput
180
+ demo.launch(server_name="0.0.0.0", server_port=7860)