dielz committed
Commit dd09c30 · verified · Parent(s): df1f5d0

Create app.py

Files changed (1)
  1. app.py +594 -0
app.py ADDED
@@ -0,0 +1,594 @@
+ import os
+ import uuid
+
+ import cv2
+ import gradio as gr
+ import numpy as np
+ import torch
+ from PIL import Image, ImageDraw, ImageFont
+ from torchvision import transforms
+
+ from diffusers import FluxFillPipeline, FluxTransformer2DModel
+ from diffusers.utils import check_min_version, load_image
+
+ WEIGHT_PATH = "dielz/textfux-test"  # transformer weights live in this repo's "transformer" subfolder
+ # scheduler = "overshoot"  # overshoot or default
+ scheduler = "default"
+
+
+ def read_words_from_text(input_text):
+     """
+     Reads a word list:
+     - If input_text is an existing file path, reads all non-empty lines from the file.
+     - Otherwise, splits the input by newlines into a list of non-empty lines.
+     """
+     if isinstance(input_text, str) and os.path.exists(input_text):
+         with open(input_text, 'r', encoding='utf-8') as f:
+             words = [line.strip() for line in f if line.strip()]
+     else:
+         words = [line.strip() for line in input_text.splitlines() if line.strip()]
+     return words
+
+ def generate_prompt(words):
+     words_str = ', '.join(f"'{word}'" for word in words)
+     prompt_template = (
+         "The pair of images highlights some white words on a black background, as well as their style on a real-world scene image. "
+         "[IMAGE1] is a template image rendering the text, with the words {words}; "
+         "[IMAGE2] shows the text content {words} naturally and correspondingly integrated into the image."
+     )
+     return prompt_template.format(words=words_str)
+
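+ # Word-free variant of the template; run_inference passes this to the pipeline's
+ # `prompt` (CLIP) input while the full word-list prompt goes to `prompt_2` (T5).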
+ prompt_template2 = (
+     "The pair of images highlights some white words on a black background, as well as their style on a real-world scene image. "
+     "[IMAGE1] is a template image rendering the text, with the words; "
+     "[IMAGE2] shows the text content naturally and correspondingly integrated into the image."
+ )
+
+ PIPE = None
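+ # The pipeline is loaded lazily and cached at module level so repeated Gradio calls reuse it.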
+ def load_flux_pipeline():
+     global PIPE
+     if PIPE is None:
+         transformer = FluxTransformer2DModel.from_pretrained(
+             WEIGHT_PATH,
+             subfolder="transformer",
+             torch_dtype=torch.bfloat16
+         )
+         PIPE = FluxFillPipeline.from_pretrained(
+             "black-forest-labs/FLUX.1-Fill-dev",
+             transformer=transformer,
+             torch_dtype=torch.bfloat16
+         ).to("cuda")
+         PIPE.transformer.to(torch.bfloat16)
+     return PIPE
+
+ def run_inference(image_input, mask_input, words_input, num_steps=50, guidance_scale=30, seed=42):
+     """
+     Invokes the Flux fill pipeline:
+     - image_input and mask_input must be the concatenated composite images (template + scene).
+     - Image dimensions are rounded down to multiples of 32 to meet the model's input requirements.
+     - A prompt is generated from the word list and passed to the pipeline.
+     """
+     if isinstance(image_input, str):
+         inpaint_image = load_image(image_input).convert("RGB")
+     else:
+         inpaint_image = image_input.convert("RGB")
+     if isinstance(mask_input, str):
+         extended_mask = load_image(mask_input).convert("RGB")
+     else:
+         extended_mask = mask_input.convert("RGB")
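+     # Dimensions are rounded down to multiples of 32: FLUX's VAE and patchified
+     # transformer require sizes divisible by the model's downsampling factor,
+     # and 32 is a safe common multiple.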
+     width, height = inpaint_image.size
+     new_width = (width // 32) * 32
+     new_height = (height // 32) * 32
+     inpaint_image = inpaint_image.resize((new_width, new_height))
+     extended_mask = extended_mask.resize((new_width, new_height))
+     words = read_words_from_text(words_input)
+     prompt = generate_prompt(words)
+     print("Generated prompt:", prompt)
+     transform = transforms.Compose([
+         transforms.ToTensor(),
+         transforms.Normalize([0.5], [0.5])
+     ])
+     mask_transform = transforms.Compose([
+         transforms.ToTensor()
+     ])
+     image_tensor = transform(inpaint_image)
+     mask_tensor = mask_transform(extended_mask)
+     generator = torch.Generator(device="cuda").manual_seed(int(seed))
+     pipe = load_flux_pipeline()
+
+     if scheduler == "overshoot":
+         try:
+             from diffusers import StochasticRFOvershotDiscreteScheduler
+             scheduler_config = pipe.scheduler.config
+             # Use a new name here: assigning to `scheduler` would shadow the
+             # module-level flag and raise UnboundLocalError on the check above.
+             overshoot_scheduler = StochasticRFOvershotDiscreteScheduler.from_config(scheduler_config)
+             overshot_func = lambda t, dt: t + dt
+
+             pipe.scheduler = overshoot_scheduler
+             pipe.scheduler.set_c(2.0)
+             pipe.scheduler.set_overshot_func(overshot_func)
+         except ImportError:
+             print("StochasticRFOvershotDiscreteScheduler not found. Please ensure you are using the repo's diffusers fork.")
+
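+     # diffusers' Flux pipelines route `prompt` to the CLIP text encoder and
+     # `prompt_2` to T5: the word-free template goes to CLIP, the full
+     # word-list prompt to T5.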
+     result = pipe(
+         height=new_height,
+         width=new_width,
+         image=inpaint_image,
+         mask_image=extended_mask,
+         num_inference_steps=num_steps,
+         generator=generator,
+         max_sequence_length=512,
+         guidance_scale=guidance_scale,
+         prompt=prompt_template2,
+         prompt_2=prompt,
+     ).images[0]
+
+     return result
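+
+ # Example call (hypothetical file names; both inputs must already be template+scene composites):
+ #     result = run_inference("composite.png", "composite_mask.png", "hello\nworld",
+ #                            num_steps=30, guidance_scale=30, seed=42)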
+
+ # =============================================================================
+ # Normal Mode: Direct Inference Call
+ # =============================================================================
+ def flux_demo_normal(image, mask, words, steps, guidance_scale, seed):
+     """
+     Gradio entry point for normal mode:
+     - Passes the input image, mask, and word list straight to run_inference.
+     - Returns the generated result image.
+     """
+     result = run_inference(image, mask, words, num_steps=steps, guidance_scale=guidance_scale, seed=seed)
+     return result
+
+ # =============================================================================
+ # Helper functions for both single-line and multi-line rendering
+ # =============================================================================
+ def extract_mask(original, drawn, threshold=30):
+     """
+     Extracts a binary mask from the original image and the user-drawn image:
+     - If 'drawn' is a dictionary containing a non-empty "mask" key, that mask is binarized directly.
+     - Otherwise, the drawn image is inverted and differenced against the original;
+       pixels whose mean channel difference exceeds `threshold` become part of the mask.
+     """
+     if isinstance(drawn, dict):
+         if "mask" in drawn and drawn["mask"] is not None:
+             drawn_mask = np.array(drawn["mask"]).astype(np.uint8)
+             if drawn_mask.ndim == 3:
+                 drawn_mask = cv2.cvtColor(drawn_mask, cv2.COLOR_RGB2GRAY)
+             _, binary_mask = cv2.threshold(drawn_mask, 50, 255, cv2.THRESH_BINARY)
+             return Image.fromarray(binary_mask).convert("RGB")
+         else:
+             drawn_img = np.array(drawn["image"]).astype(np.uint8)
+             drawn = 255 - drawn_img
+     orig_arr = np.array(original).astype(np.int16)
+     drawn_arr = np.array(drawn).astype(np.int16)
+     diff = np.abs(drawn_arr - orig_arr)
+     diff_gray = np.mean(diff, axis=-1)
+     binary_mask = (diff_gray > threshold).astype(np.uint8) * 255
+     return Image.fromarray(binary_mask).convert("RGB")
+
+ def get_next_seq_number():
+     """
+     Finds the next available sequential number (format: 0001, 0002, ...) in the 'outputs_my' directory.
+     The first number for which 'result_XXXX.png' does not exist is returned as a formatted string.
+     """
+     counter = 1
+     while True:
+         seq_str = f"{counter:04d}"
+         result_path = os.path.join("outputs_my", f"result_{seq_str}.png")
+         if not os.path.exists(result_path):
+             return seq_str
+         counter += 1
+
+ # =============================================================================
+ # Single-line text rendering functions
+ # =============================================================================
+ def draw_glyph_flexible(font, text, width, height, max_font_size=140):
+     """
+     Renders text horizontally centered on a canvas of the given size and returns a PIL Image.
+     The font size is adjusted automatically to fit the canvas, capped at max_font_size.
+     """
+     img = Image.new(mode='RGB', size=(width, height), color='black')
+     if not text or not text.strip():
+         return img
+     draw = ImageDraw.Draw(img)
+
+     # Initial font size for calculating the scale ratio
+     g_size = 50
+     try:
+         new_font = font.font_variant(size=g_size)
+     except Exception:
+         new_font = font
+
+     left, top, right, bottom = new_font.getbbox(text)
+     text_width_initial = max(right - left, 1)
+     text_height_initial = max(bottom - top, 1)
+
+     # Calculate scale ratios based on width and height
+     width_ratio = width * 0.9 / text_width_initial
+     height_ratio = height * 0.9 / text_height_initial
+     ratio = min(width_ratio, height_ratio)
+
+     # Raise the maximum font size for wide originals
+     if width > 1280:
+         max_font_size = 200
+     final_font_size = int(g_size * ratio)
+     final_font_size = min(final_font_size, max_font_size)  # apply upper limit
+
+     # Use the final calculated font size
+     try:
+         final_font = font.font_variant(size=max(final_font_size, 10))
+     except Exception:
+         final_font = font
+
+     draw.text((width / 2, height / 2), text, font=final_font, fill='white', anchor='mm')
+     return img
+
+ # =============================================================================
+ # Multi-line text rendering functions
+ # =============================================================================
+ def insert_spaces(text, num_spaces):
+     """
+     Inserts the given number of spaces between characters to widen the tracking
+     during text rendering.
+     """
+     if len(text) <= 1:
+         return text
+     return (' ' * num_spaces).join(list(text))
+
+
+ def draw_glyph2(
+     font,
+     text,
+     polygon,
+     vertAng=10,
+     scale=1,
+     width=512,
+     height=512,
+     add_space=True,
+     scale_factor=2,
+     rotate_resample=Image.BICUBIC,
+     downsample_resample=Image.Resampling.LANCZOS
+ ):
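+     """
+     Renders `text` into the region described by `polygon` and returns an RGBA numpy array:
+     - Fits the text to the polygon's minimum-area rectangle, inserting spaces or
+       shrinking the font as needed.
+     - Switches to vertical, character-by-character layout when the box is roughly
+       axis-aligned (within `vertAng` degrees) and taller than it is wide.
+     - Draws on an enlarged canvas (`scale_factor`), rotates the text layer to the
+       region's orientation, then downsamples to (width, height).
+     """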
+     big_w = width * scale_factor
+     big_h = height * scale_factor
+
+     big_polygon = polygon * scale_factor * scale
+     rect = cv2.minAreaRect(big_polygon.astype(np.float32))
+     box = cv2.boxPoints(rect)
+     box = np.intp(box)
+
+     w, h = rect[1]
+     angle = rect[2]
+     if angle < -45:
+         angle += 90
+     angle = -angle
+     if w < h:
+         angle += 90
+
+     vert = False
+     if (abs(angle) % 90 < vertAng or abs(90 - abs(angle) % 90) % 90 < vertAng):
+         _w = max(box[:, 0]) - min(box[:, 0])
+         _h = max(box[:, 1]) - min(box[:, 1])
+         if _h >= _w:
+             vert = True
+             angle = 0
+
+     big_img = Image.new("RGBA", (big_w, big_h), (0, 0, 0, 0))
+     tmp = Image.new("RGB", big_img.size, "white")
+     tmp_draw = ImageDraw.Draw(tmp)
+
+     _, _, _tw, _th = tmp_draw.textbbox((0, 0), text, font=font)
+     if _th == 0:
+         text_w = 0
+     else:
+         w_f, h_f = float(w), float(h)
+         text_w = min(w_f, h_f) * (_tw / _th)
+
+     if text_w <= max(w, h):
+         if len(text) > 1 and not vert and add_space:
+             for i in range(1, 100):
+                 text_sp = insert_spaces(text, i)
+                 _, _, tw2, th2 = tmp_draw.textbbox((0, 0), text_sp, font=font)
+                 if th2 != 0:
+                     if min(w, h) * (tw2 / th2) > max(w, h):
+                         break
+             text = insert_spaces(text, i - 1)
+         font_size = min(w, h) * 0.80
+     else:
+         shrink = 0.75 if vert else 0.85
+         if text_w != 0:
+             font_size = min(w, h) / (text_w / max(w, h)) * shrink
+         else:
+             font_size = min(w, h) * 0.80
+
+     new_font = font.font_variant(size=int(font_size))
+     left, top, right, bottom = new_font.getbbox(text)
+     text_width = right - left
+     text_height = bottom - top
+
+     layer = Image.new("RGBA", big_img.size, (0, 0, 0, 0))
+     draw_layer = ImageDraw.Draw(layer)
+     cx, cy = rect[0]
+     if not vert:
+         draw_layer.text(
+             (cx - text_width // 2, cy - text_height // 2 - top),
+             text,
+             font=new_font,
+             fill=(255, 255, 255, 255)
+         )
+     else:
+         _w_ = max(box[:, 0]) - min(box[:, 0])
+         x_s = min(box[:, 0]) + _w_ // 2 - text_height // 2
+         y_s = min(box[:, 1])
+         for c in text:
+             draw_layer.text((x_s, y_s), c, font=new_font, fill=(255, 255, 255, 255))
+             _, _t, _, _b = new_font.getbbox(c)
+             y_s += _b
+
+     rotated_layer = layer.rotate(
+         angle,
+         expand=True,
+         center=(cx, cy),
+         resample=rotate_resample
+     )
+
+     xo = int((big_img.width - rotated_layer.width) // 2)
+     yo = int((big_img.height - rotated_layer.height) // 2)
+     big_img.paste(rotated_layer, (xo, yo), rotated_layer)
+
+     final_img = big_img.resize((width, height), downsample_resample)
+     final_np = np.array(final_img)
+     return final_np
+
+ def render_glyph_multi(original, computed_mask, texts):
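+     """
+     Renders one line of `texts` into each connected region of `computed_mask`:
+     - Regions come from contour detection, with tiny specks (bounding-box area < 50 px)
+       discarded, sorted top-to-bottom, then left-to-right.
+     - Each region's line is drawn with draw_glyph2 so the text follows the region's
+       orientation; all layers are alpha-composited and returned as RGB.
+     """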
+     mask_np = np.array(computed_mask.convert("L"))
+     contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+     regions = []
+     for cnt in contours:
+         x, y, w, h = cv2.boundingRect(cnt)
+         if w * h < 50:
+             continue
+         regions.append((x, y, w, h, cnt))
+     regions = sorted(regions, key=lambda r: (r[1], r[0]))
+
+     render_img = Image.new("RGBA", original.size, (0, 0, 0, 0))
+     try:
+         base_font = ImageFont.truetype("resource/font/Arial-Unicode-Regular.ttf", 40)
+     except Exception:
+         base_font = ImageFont.load_default()
+
+     for i, region in enumerate(regions):
+         if i >= len(texts):
+             break
+         text = texts[i].strip()
+         if not text:
+             continue
+         cnt = region[4]
+         polygon = cnt.reshape(-1, 2)
+         rendered_np = draw_glyph2(
+             font=base_font,
+             text=text,
+             polygon=polygon,
+             vertAng=10,
+             scale=1,
+             width=original.size[0],
+             height=original.size[1],
+             add_space=True,
+             scale_factor=1,
+             rotate_resample=Image.BICUBIC,
+             downsample_resample=Image.Resampling.LANCZOS
+         )
+         rendered_img = Image.fromarray(rendered_np, mode="RGBA")
+         render_img = Image.alpha_composite(render_img, rendered_img)
+     return render_img.convert("RGB")
+
+
+ def choose_concat_direction(height, width):
+     """
+     Selects the concatenation direction based on the original image's aspect ratio:
+     - If height is greater than width, horizontal concatenation is used.
+     - Otherwise, vertical concatenation is used.
+     """
+     return 'horizontal' if height > width else 'vertical'
+
+ def is_multiline_text(text):
+     """
+     Determines whether the input text should be treated as multi-line, i.e., whether it
+     contains more than one non-empty line.
+     """
+     lines = [line.strip() for line in text.splitlines() if line.strip()]
+     return len(lines) > 1
+
+ # =============================================================================
+ # Custom Mode: Unified function that handles both single-line and multi-line
+ # =============================================================================
+ def flux_demo_custom(original_image, drawn_mask, words, steps, guidance_scale, seed):
+     """
+     Unified custom-mode Gradio entry point:
+     - Detects whether to use single-line or multi-line rendering from the input text.
+     - Text containing line breaks uses multi-line rendering; otherwise single-line.
+     """
+     computed_mask = extract_mask(original_image, drawn_mask)
+
+     # Determine rendering mode based on text input
+     if is_multiline_text(words):
+         print("Using multi-line text rendering mode")
+         return flux_demo_custom_multiline(original_image, computed_mask, words, steps, guidance_scale, seed)
+     else:
+         print("Using single-line text rendering mode")
+         return flux_demo_custom_singleline(original_image, computed_mask, words, steps, guidance_scale, seed)
+
+ def flux_demo_custom_multiline(original_image, computed_mask, words, steps, guidance_scale, seed):
+     """
+     Multi-line rendering mode:
+     1. Splits the user input into lines, each line corresponding to one mask region.
+     2. Calls render_glyph_multi to render rotated/curved text for each region.
+     3. Chooses the concatenation direction from the original image's dimensions.
+     4. Passes the concatenated images to run_inference and returns the cropped result.
+     """
+     texts = read_words_from_text(words)
+     render_img = render_glyph_multi(original_image, computed_mask, texts)
+     width, height = original_image.size
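+     # The rendered template half keeps an all-zero mask, so inpainting only
+     # modifies the scene half of the composite.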
+     empty_mask = np.zeros((height, width), dtype=np.uint8)
+     direction = choose_concat_direction(height, width)
+     if direction == 'horizontal':
+         combined_image = np.hstack((np.array(render_img), np.array(original_image)))
+         combined_mask = np.hstack((empty_mask, np.array(computed_mask.convert("L"))))
+     else:
+         combined_image = np.vstack((np.array(render_img), np.array(original_image)))
+         combined_mask = np.vstack((empty_mask, np.array(computed_mask.convert("L"))))
+     combined_mask = cv2.cvtColor(combined_mask, cv2.COLOR_GRAY2RGB)
+     composite_image = Image.fromarray(combined_image)
+     composite_mask = Image.fromarray(combined_mask)
+     result = run_inference(composite_image, composite_mask, words, num_steps=steps, guidance_scale=guidance_scale, seed=seed)
+
+     # Crop the result, keeping only the scene image portion.
+     width, height = result.size
+     if direction == 'horizontal':
+         cropped_result = result.crop((width // 2, 0, width, height))
+     else:
+         cropped_result = result.crop((0, height // 2, width, height))
+
+     save_results(result, cropped_result, computed_mask, original_image, composite_image, words)
+     return cropped_result, composite_image, composite_mask
+
+ def flux_demo_custom_singleline(original_image, computed_mask, words, steps, guidance_scale, seed):
+     """
+     Single-line rendering mode:
+     1. Concatenates the user input into a single line of text.
+     2. Renders that line above the original image.
+     3. Runs model inference and crops the result back to the scene portion.
+     """
+     # Process text, concatenate into a single line
+     text_lines = read_words_from_text(words)
+     single_line_text = ' '.join(text_lines)
+
+     # Calculate dimensions and generate the concatenated image and mask
+     w, h = original_image.size
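+     # Reserve a strip for the rendered text whose height is 5/32 (~15.6%) of the image width.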
+     text_height_ratio = 0.15625
+     text_render_height = int(w * text_height_ratio)
+
+     # Load font
+     try:
+         font = ImageFont.truetype("resource/font/Arial-Unicode-Regular.ttf", 60)
+     except IOError:
+         font = ImageFont.load_default()
+         print("Warning: font not found, using default font.")
+
+     # Render the single-line text image
+     text_render_pil = draw_glyph_flexible(font, single_line_text, width=w, height=text_render_height)
+     # Create a pure black mask the same size as the text rendering
+     text_mask_pil = Image.new("RGB", text_render_pil.size, "black")
+
+     # Always use vertical concatenation
+     composite_image = Image.fromarray(np.vstack((np.array(text_render_pil), np.array(original_image))))
+     composite_mask = Image.fromarray(np.vstack((np.array(text_mask_pil), np.array(computed_mask))))
+
+     # Call model inference
+     full_result = run_inference(composite_image, composite_mask, words, num_steps=steps, guidance_scale=guidance_scale, seed=seed)
+
+     # Crop the result, keeping only the scene image portion
+     res_w, res_h = full_result.size
+     orig_h = h  # original scene image height
+     # The result may have been resized to multiples of 32, so compute the crop
+     # line proportionally rather than at the raw pixel offset.
+     crop_top_edge = int(res_h * (text_render_height / (orig_h + text_render_height)))
+     cropped_result = full_result.crop((0, crop_top_edge, res_w, res_h))
+
+     save_results(full_result, cropped_result, computed_mask, original_image, composite_image, words)
+     return cropped_result, composite_image, composite_mask
+
+ def save_results(result, cropped_result, computed_mask, original_image, composite_image, words):
+     """
+     Saves all related images and the word list under the 'outputs_my' directory.
+     """
+     os.makedirs("outputs_my", exist_ok=True)
+     os.makedirs("outputs_my/crop", exist_ok=True)
+     os.makedirs("outputs_my/mask", exist_ok=True)
+     os.makedirs("outputs_my/ori", exist_ok=True)
+     os.makedirs("outputs_my/composite", exist_ok=True)
+     os.makedirs("outputs_my/txt", exist_ok=True)
+
+     seq = get_next_seq_number()
+     result_filename = os.path.join("outputs_my", f"result_{seq}.png")
+     crop_filename = os.path.join("outputs_my", "crop", f"crop_{seq}.png")
+     mask_filename = os.path.join("outputs_my", "mask", f"mask_{seq}.png")
+     ori_filename = os.path.join("outputs_my", "ori", f"ori_{seq}.png")
+     composite_filename = os.path.join("outputs_my", "composite", f"composite_{seq}.png")
+     txt_filename = os.path.join("outputs_my", "txt", f"words_{seq}.txt")
+
+     # Save images
+     result.save(result_filename)
+     cropped_result.save(crop_filename)
+     computed_mask.save(mask_filename)
+     original_image.save(ori_filename)
+     composite_image.save(composite_filename)
+     with open(txt_filename, "w", encoding="utf-8") as f:
+         f.write(words)
+
+ # =============================================================================
+ # Gradio Interface
+ # =============================================================================
+ with gr.Blocks(title="Flux Inference Demo") as demo:
+     gr.Markdown("## Flux Inference Demo")
+     with gr.Tabs():
+         with gr.TabItem("Custom Mode"):
+             with gr.Row():
+                 with gr.Column(scale=1, min_width=350):
+                     gr.Markdown("### Image Input")
+                     original_image_custom = gr.Image(type="pil", label="Upload Original Image")
+                     gr.Markdown("### Draw Mask on Image")
+                     mask_drawing_custom = gr.Image(type="pil", label="Draw Mask on Original Image", tool="sketch")
+
+                 with gr.Column(scale=1, min_width=350):
+                     gr.Markdown("### Parameter Settings")
+                     words_custom = gr.Textbox(
+                         lines=5,
+                         placeholder="Enter text here (a single line is recommended: faster and more reliable).\nMultiple lines are supported; each line is rendered in its corresponding mask region.",
+                         label="Text Input"
+                     )
+                     steps_custom = gr.Slider(minimum=10, maximum=100, step=1, value=30, label="Inference Steps")
+                     guidance_scale_custom = gr.Slider(minimum=1, maximum=50, step=1, value=30, label="Guidance Scale")
+                     seed_custom = gr.Number(value=42, label="Random Seed")
+                     run_custom = gr.Button("Generate Results")
+
+             with gr.Tabs():
+                 with gr.TabItem("Generated Results"):
+                     output_result_custom = gr.Image(type="pil", label="Generated Results")
+                 with gr.TabItem("Input Preview"):
+                     output_composite_custom = gr.Image(type="pil", label="Concatenated Original Image")
+                     output_mask_custom = gr.Image(type="pil", label="Concatenated Mask")
+
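+             # Mirror the uploaded image into the sketch component so the mask
+             # is drawn directly on top of it.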
+             original_image_custom.change(fn=lambda x: x, inputs=original_image_custom, outputs=mask_drawing_custom)
+             run_custom.click(fn=flux_demo_custom,
+                              inputs=[original_image_custom, mask_drawing_custom, words_custom, steps_custom, guidance_scale_custom, seed_custom],
+                              outputs=[output_result_custom, output_composite_custom, output_mask_custom])
+
+         with gr.TabItem("Normal Mode"):
+             with gr.Row():
+                 with gr.Column(scale=1, min_width=350):
+                     gr.Markdown("### Image Input")
+                     image_normal = gr.Image(type="pil", label="Image Input")
+                     gr.Markdown("### Mask Input")
+                     mask_normal = gr.Image(type="pil", label="Mask Input")
+                 with gr.Column(scale=1, min_width=350):
+                     gr.Markdown("### Parameter Settings")
+                     words_normal = gr.Textbox(lines=5, placeholder="Enter words here, one per line", label="Text List")
+                     steps_normal = gr.Slider(minimum=10, maximum=100, step=1, value=30, label="Inference Steps")
+                     guidance_scale_normal = gr.Slider(minimum=1, maximum=50, step=1, value=30, label="Guidance Scale")
+                     seed_normal = gr.Number(value=42, label="Random Seed")
+                     run_normal = gr.Button("Generate Results")
+                     output_normal = gr.Image(type="pil", label="Generated Results")
+             run_normal.click(fn=flux_demo_normal,
+                              inputs=[image_normal, mask_normal, words_normal, steps_normal, guidance_scale_normal, seed_normal],
+                              outputs=output_normal)
+
+     gr.Markdown(
+         """
+         ### Instructions
+         - **Custom Mode**:
+             - Upload an original image, then draw a mask on it.
+             - **Single-line mode**: enter text without line breaks; all text is joined and rendered as one line above the image.
+             - **Multi-line mode**: enter text with line breaks; each line is rendered in the corresponding mask region with skewed/curved effects.
+             - The mode is detected automatically from your text input.
+         - **Normal Mode**: directly upload an image, a mask, and a list of words to generate the result image.
+         """
+     )
+
+ if __name__ == "__main__":
+     check_min_version("0.30.1")
+     demo.launch()