# ip-composer / IP_Composer / create_grids.py
# Hugging Face Hub file-page header (scraped residue): uploaded by linoyts (HF Staff)
# in "Upload 64 files", commit c025a3d (verified), 8.57 kB.
import argparse
import os
import json
import itertools
from PIL import Image, ImageDraw, ImageFont
def wrap_text(text, max_width, draw, font):
    """Greedily wrap *text* into lines whose rendered width fits *max_width*.

    Words (split on single spaces) are appended to the current line until the
    rendered width -- the right edge of ``draw.textbbox`` measured at the origin
    (0, 0) -- would exceed *max_width*; the line is then flushed and the
    overflowing word starts the next one. A single word wider than *max_width*
    is placed on its own line (it is not split), and no empty line is emitted
    for it.

    Args:
        text: String to wrap.
        max_width: Maximum line width in pixels.
        draw: Object exposing ``textbbox(xy, text, font=...)``, e.g. a
            ``PIL.ImageDraw.ImageDraw``.
        font: Font forwarded to ``draw.textbbox``.

    Returns:
        list[str]: The wrapped lines; never empty (``""`` yields ``[""]``).
    """
    lines = []
    current_line = []
    for word in text.split(' '):
        candidate = current_line + [word]
        width = draw.textbbox((0, 0), ' '.join(candidate), font=font)[2]
        if width > max_width and current_line:
            # Flush the words that fit; the overflowing word opens the next line.
            lines.append(' '.join(current_line))
            current_line = [word]
        else:
            # Either the word fits, or the line is empty and the word must go
            # somewhere -- flushing here would emit a spurious empty line.
            current_line = candidate
    if current_line:
        lines.append(' '.join(current_line))
    return lines
def _text_block_height(draw, lines, font):
    """Return the vertical advance of *lines* when stacked by ``_draw_centered_lines``."""
    return sum(draw.textbbox((0, 0), line, font=font)[3] for line in lines)


def _draw_centered_lines(draw, lines, font, left, width, top):
    """Draw *lines* in black from *top* downwards, each centred in ``[left, left + width)``.

    Each line advances ``y`` by the bottom of its bounding box measured at the origin.
    """
    y = top
    for line in lines:
        _, _, right, bottom = draw.textbbox((0, 0), line, font=font)
        draw.text((left + (width - right) // 2, y), line, fill="black", font=font)
        y += bottom


def image_grid_with_titles(imgs, rows, cols, top_titles, left_titles, margin=20,
                           cell_size=(256, 256), title_height=50, title_width=120,
                           font_size=20):
    """Paste *imgs* row-major into a ``rows`` x ``cols`` grid with column and row titles.

    Every image is resized to *cell_size*. Column titles are wrapped to the cell
    width and centred in a band of *title_height* pixels at the top; row titles
    are wrapped to ``title_width - 10`` pixels and centred in a column of
    *title_width* pixels on the left. Cells are separated by *margin* pixels on
    a white RGB canvas.

    Args:
        imgs: ``rows * cols`` PIL images, in row-major order.
        rows: Number of grid rows (may be 0).
        cols: Number of grid columns.
        top_titles: One title per column.
        left_titles: One title per row.
        margin: Gap in pixels between cells and around the grid.
        cell_size: ``(width, height)`` every image is resized to.
        title_height: Height in pixels of the top title band.
        title_width: Width in pixels of the left title column.
        font_size: Point size for ``arial.ttf`` (Pillow's default font is used
            if it cannot be loaded).

    Returns:
        PIL.Image.Image: The composed RGB grid.

    Raises:
        ValueError: If the image or title counts do not match the grid shape.
    """
    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    if len(imgs) != rows * cols:
        raise ValueError(f"expected rows * cols = {rows * cols} images, got {len(imgs)}")
    if len(top_titles) != cols:
        raise ValueError(f"expected {cols} top titles, got {len(top_titles)}")
    if len(left_titles) != rows:
        raise ValueError(f"expected {rows} left titles, got {len(left_titles)}")

    # Take the cell size from the parameter, not imgs[0], so a 0-row grid
    # (e.g. an empty concept directory) does not raise IndexError.
    w, h = cell_size
    grid_width = cols * (w + margin) + title_width + margin
    grid_height = rows * (h + margin) + title_height + margin
    grid = Image.new('RGB', size=(grid_width, grid_height), color='white')
    draw = ImageDraw.Draw(grid)
    try:
        font = ImageFont.truetype("arial.ttf", font_size)
    except OSError:
        # arial.ttf is not installed everywhere (e.g. most Linux systems).
        font = ImageFont.load_default()

    # Column titles: vertically centred in the top band, horizontally over each column.
    for col, title in enumerate(top_titles):
        lines = wrap_text(title, w, draw, font)
        top = (title_height - _text_block_height(draw, lines, font)) // 2
        left = col * (w + margin) + title_width + margin
        _draw_centered_lines(draw, lines, font, left, w, top)

    # Row titles: vertically centred on each row, horizontally in the left column.
    for row, title in enumerate(left_titles):
        lines = wrap_text(title, title_width - 10, draw, font)
        row_top = row * (h + margin) + title_height + margin
        top = row_top + (h - _text_block_height(draw, lines, font)) // 2
        _draw_centered_lines(draw, lines, font, 0, title_width, top)

    for i, img in enumerate(imgs):
        x_pos = (i % cols) * (w + margin) + title_width + margin
        y_pos = (i // cols) * (h + margin) + title_height + margin
        grid.paste(img.resize((w, h)), box=(x_pos, y_pos))
    return grid
# Size of the white tile used for missing files and empty grid cells.
_PLACEHOLDER_SIZE = (256, 256)


def _placeholder_image():
    """Return a blank white RGB tile."""
    return Image.new("RGB", _PLACEHOLDER_SIZE, color="white")


def _load_image(path):
    """Load *path* fully as RGB and close the file, or return a white tile if it is missing.

    ``Image.open`` is lazy; converting inside the ``with`` block reads the pixel
    data so the file descriptor can be released immediately.
    """
    if not os.path.exists(path):
        return _placeholder_image()
    with Image.open(path) as img:
        return img.convert("RGB")


def _sample_tiles(sample_dir, n_slots):
    """Return exactly *n_slots* tiles: samples from *sample_dir* (sorted by name), padded with blanks.

    Extra sample files beyond *n_slots* are ignored; a missing directory yields only blanks.
    """
    names = sorted(os.listdir(sample_dir))[:n_slots] if os.path.isdir(sample_dir) else []
    tiles = [_load_image(os.path.join(sample_dir, name)) for name in names]
    tiles.extend(_placeholder_image() for _ in range(n_slots - len(tiles)))
    return tiles


def create_grids(config):
    """Write one titled image grid per base image (and per fixed-concept combination).

    With concept directories ``[c1, ..., ck]``, the last one (``ck``) varies along
    the rows; every combination of one image from each of ``c1..c(k-1)`` gets its
    own grid. Each row holds: the base image, the fixed-concept images, the
    last-concept image, then the generated samples from
    ``<output_base_dir>/<base>_to_<last>`` (single concept) or
    ``<output_base_dir>/<base>_to_<fixed1>_..._<last>`` (several concepts).
    Every row is padded or truncated to the number of title columns, so missing
    or surplus samples never shift later rows.

    Args:
        config: Mapping with keys
            ``input_dir_base`` -- directory of base images;
            ``input_dirs_concepts`` -- non-empty list of concept directories;
            ``output_base_dir`` -- directory holding sample sub-directories; grids
            are saved to ``<output_base_dir>/grids``;
            ``num_samples`` (optional, default 4) -- number of sample columns.

    Raises:
        ValueError: If ``input_dirs_concepts`` is empty.
    """
    concept_dirs = config["input_dirs_concepts"]
    if not concept_dirs:
        raise ValueError("config['input_dirs_concepts'] must list at least one directory")
    num_samples = config.get("num_samples", 4)
    base_dir = config["input_dir_base"]
    output_base_dir = config["output_base_dir"]
    output_grid_dir = os.path.join(output_base_dir, "grids")
    os.makedirs(output_grid_dir, exist_ok=True)

    single_concept = len(concept_dirs) == 1
    fixed_concepts = concept_dirs[:-1]
    last_concept_dir = concept_dirs[-1]

    # Sorted so grid rows / combinations do not depend on filesystem listing order.
    base_images = sorted(os.listdir(base_dir))
    last_concept_images = sorted(os.listdir(last_concept_dir))
    fixed_concept_images = [sorted(os.listdir(d)) for d in fixed_concepts]

    if single_concept:
        concept_titles = ["Concept 1"]
    else:
        concept_titles = [f"Concept {i + 1}" for i in range(len(fixed_concepts))] + ["Last Concept"]
    top_titles = ["Base Image"] + concept_titles + ["Samples"] + [""] * (num_samples - 1)
    left_titles = ["" for _ in last_concept_images]
    n_cols = len(top_titles)

    for base_image in base_images:
        base_image_path = os.path.join(base_dir, base_image)
        # With a single concept, product() of no iterables yields one empty tuple,
        # so this loop runs exactly once per base image.
        for fixed_combination in itertools.product(*fixed_concept_images):
            # Base + fixed-concept tiles repeat on every row: load them once per grid.
            fixed_tiles = [_load_image(base_image_path)] + [
                _load_image(os.path.join(concept_dir, concept_image))
                for concept_dir, concept_image in zip(fixed_concepts, fixed_combination)
            ]

            images = []
            for last_image in last_concept_images:
                row = fixed_tiles + [_load_image(os.path.join(last_concept_dir, last_image))]
                if single_concept:
                    sample_name = f"{base_image}_to_{last_image}"
                else:
                    sample_name = f"{base_image}_to_" + "_".join(fixed_combination) + f"_{last_image}"
                # Fill the rest of the row to exactly n_cols so rows stay column-aligned.
                row += _sample_tiles(os.path.join(output_base_dir, sample_name), n_cols - len(row))
                images.extend(row)

            grid = image_grid_with_titles(
                imgs=images,
                rows=len(left_titles),
                cols=n_cols,
                top_titles=top_titles,
                left_titles=left_titles,
            )

            if single_concept:
                grid_name = f"grid_base_{base_image}_concept1.png"
            else:
                grid_name = f"grid_base_{base_image}_combo_{'_'.join(map(str, fixed_combination))}.png"
            grid_save_path = os.path.join(output_grid_dir, grid_name)
            grid.save(grid_save_path)
            print(f"Grid saved at {grid_save_path}")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Create image grids based on a configuration file.")
    parser.add_argument("config_path", type=str, help="Path to the configuration JSON file.")
    args = parser.parse_args()
    # JSON text is UTF-8; don't depend on the platform's locale encoding
    # (which would mangle non-ASCII directory names on e.g. Windows).
    with open(args.config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)
    # Default to 4 sample columns when the config omits num_samples.
    config.setdefault("num_samples", 4)
    create_grids(config)