# image-variation-experiments / inference_pixart_flux_redux.py
# Author: ivand-all3d — initial commit 1f61707
import argparse
import time
import torch
from diffusers import PixArtAlphaPipeline, PixArtTransformer2DModel
from diffusers.pipelines.flux import FluxPriorReduxPipeline
from transformers import SiglipImageProcessor
from pathlib import Path
from PIL import Image
# Model handles shared with generate(); main() assigns all three before calling it.
pipe = None  # PixArtAlphaPipeline whose transformer is swapped for the pixart-flux-redux one
redux = None  # FluxPriorReduxPipeline; generate() uses its image_encoder for SigLIP features
redux_embedder = None  # redux.image_embedder; its image_embeds output is used as prompt_embeds
def _safe_filename_stem(prompt, max_len=40):
    """Turn the first ``max_len`` characters of *prompt* into a safe filename stem.

    Anything other than letters, digits, ``-`` and ``_`` becomes ``_``. The
    prompt therefore cannot inject path separators or produce a hidden
    dot-file. An empty result falls back to ``"output"``.
    """
    stem = "".join(
        ch if (ch.isalnum() or ch in "-_") else "_" for ch in prompt[:max_len]
    )
    return stem or "output"


def _concat_horizontally(images, gap=1):
    """Paste *images* side by side on a new RGB canvas, ``gap`` pixels apart."""
    widths, heights = zip(*(img.size for img in images))
    canvas = Image.new("RGB", (sum(widths) + gap * (len(images) - 1), max(heights)))
    x_offset = 0
    for img in images:
        canvas.paste(img, (x_offset, 0))
        x_offset += img.width + gap
    return canvas


def generate(prompt, image_prompt=None, guidance_scale=2, num_images=4, resolution=512):
    """Generate ``num_images`` variations of *image_prompt* and save them as one PNG.

    Conditioning pipeline:

    1. The image prompt is encoded by ``redux.image_encoder``.
    2. The encoder output is projected by ``redux_embedder``.
    3. The result is passed to ``pipe`` as ``prompt_embeds`` in place of text
       embeddings.

    The text *prompt* is NOT used for conditioning; it only names the output
    file.

    The saved image shows the resized image prompt on top, with the generated
    images in one row below it, separated by 1-pixel gaps. It is written to
    ``image-outputs/<prompt-stem>.<unix-time>.png``.

    Requires the module globals ``pipe``, ``redux`` and ``redux_embedder`` to
    be initialised (``main()`` does this) and a CUDA device.

    Args:
        prompt: Text used only to derive the output file name.
        image_prompt: PIL image to create variations of. Required; the
            ``None`` default is kept only for signature compatibility.
        guidance_scale: Guidance scale passed to the pipeline.
        num_images: Number of images to generate; must be >= 1.
        resolution: Height and width requested for each generated image, and
            the size of the image-prompt thumbnail.

    Returns:
        The ``pathlib.Path`` of the saved PNG.

    Raises:
        ValueError: If *image_prompt* is ``None`` or *num_images* < 1.
    """
    if image_prompt is None:
        # The transformer is conditioned only on the image embedding; without
        # an image there is nothing to build prompt_embeds from.
        raise ValueError("generate() requires an image_prompt (a PIL image).")
    if num_images < 1:
        raise ValueError(f"num_images must be >= 1, got {num_images}")

    # Convert once: used both for the SigLIP input and the output thumbnail.
    rgb_prompt = image_prompt.convert("RGB")

    with torch.no_grad():
        clip_image_processor = SiglipImageProcessor(size={"height": 384, "width": 384})
        clip_pixel_values = clip_image_processor.preprocess(
            rgb_prompt, return_tensors="pt"
        ).pixel_values.to("cuda", dtype=torch.bfloat16)
        image_prompt_latents = redux.image_encoder(clip_pixel_values).last_hidden_state
        image_prompt_embeds = redux_embedder(image_prompt_latents).image_embeds
        # Keep the first 120 tokens (presumably PixArt-alpha's 120-token T5
        # context) and rescale by 0.04 (presumably the scale the
        # pixart-flux-redux transformer was trained with). TODO confirm both.
        prompt_embeds = image_prompt_embeds[:, :120, :] * 0.04
        # Every kept token is valid, so the mask is all ones (batch, tokens).
        attention_mask = torch.ones(prompt_embeds.shape[:2], device=prompt_embeds.device)
        images = pipe(
            prompt_embeds=prompt_embeds,
            prompt_attention_mask=attention_mask,
            negative_prompt="",
            height=resolution,
            width=resolution,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images,
        ).images

    row = _concat_horizontally(images)

    # Stack the resized image prompt above the row of generated images.
    thumb = rgb_prompt.resize((resolution, resolution), Image.Resampling.BILINEAR)
    stacked = Image.new("RGB", (max(row.width, thumb.width), row.height + 1 + resolution))
    stacked.paste(thumb, (0, 0))
    stacked.paste(row, (0, resolution + 1))

    out_dir = Path("image-outputs")
    out_dir.mkdir(parents=True, exist_ok=True)
    output_filename = out_dir / f"{_safe_filename_stem(prompt)}.{int(time.time())}.png"
    stacked.save(output_filename)
    print(f"Saved output to {output_filename}")
    return output_filename
def main():
    """Parse CLI arguments, load the models onto CUDA, and run generate().

    Two pipelines are loaded:

    - ``PixArt-alpha/PixArt-XL-2-512x512``, with its transformer replaced by
      the ``pixart-flux-redux`` checkpoint.
    - The ``FLUX.1-Redux-dev`` prior pipeline. Its image encoder and embedder
      turn the image prompt into prompt embeddings.

    Both are stored in the module globals read by generate(). The
    ``pixart-flux-redux`` and ``FLUX.1-Redux-dev`` paths have no org prefix,
    so they are presumably local directories (TODO confirm).
    """
    global pipe, redux, redux_embedder

    parser = argparse.ArgumentParser(
        description="Generate images using an image and a text prompt (PixArt Flux Redux)."
    )
    parser.add_argument("--prompt", type=str, default="",
                        help='Text prompt; currently only used to name the output file (default: "")')
    # generate() cannot run without an image prompt, so require it up front
    # instead of failing after the slow model loads.
    parser.add_argument("--image_prompt", type=str, required=True,
                        help="Path to the image to create variations of (required)")
    parser.add_argument("--guidance_scale", type=float, default=2,
                        help="Guidance scale for image generation (default: 2)")
    parser.add_argument("--num_images", type=int, default=4,
                        help="Number of images to generate (default: 4)")
    parser.add_argument("--resolution", type=int, default=512,
                        help="Resolution for generated images (default: 512)")
    args = parser.parse_args()
    if args.num_images < 1:
        parser.error("--num_images must be at least 1")
    if args.resolution < 1:
        parser.error("--resolution must be a positive integer")

    # Open the image before loading the models so a bad path fails fast.
    # convert() forces a full decode, so the file handle can be closed here.
    with Image.open(args.image_prompt) as im:
        img_prompt = im.convert("RGB")

    pipe = PixArtAlphaPipeline.from_pretrained(
        "PixArt-alpha/PixArt-XL-2-512x512", transformer=None, torch_dtype=torch.bfloat16
    )
    transformer = PixArtTransformer2DModel.from_pretrained("pixart-flux-redux", torch_dtype=torch.bfloat16)
    pipe.transformer = transformer
    redux = FluxPriorReduxPipeline.from_pretrained("FLUX.1-Redux-dev", torch_dtype=torch.bfloat16)
    redux_embedder = redux.image_embedder
    redux.to("cuda")
    pipe.to("cuda")

    generate(args.prompt, image_prompt=img_prompt, guidance_scale=args.guidance_scale,
             num_images=args.num_images, resolution=args.resolution)
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()