"""Generate images with SDXL (base + refiner ensemble) from prompts in a JSON file.

Usage:
    python script.py --json_filename data.json --cuda 0

Each entry of the JSON list is expected to look like
{"image": "<subdir>/<name>.jpg", "conversations": [..., {"value": "<prompt>"}]},
where the second conversation turn holds the text prompt for that image.
"""

import argparse
import json
import os

import torch
from diffusers import DiffusionPipeline
from PIL import Image  # NOTE(review): unused here — kept in case other code relies on it

parser = argparse.ArgumentParser(description="Diffusion Pipeline with Arguments")
parser.add_argument(
    "--json_filename",
    type=str,
    required=True,
    help="Path to the JSON file containing text data",
)
parser.add_argument(
    "--cuda", type=int, required=True, help="CUDA device to use for processing"
)
args = parser.parse_args()

json_filename = args.json_filename
cuda_device = f"cuda:{args.cuda}"
print(json_filename, cuda_device)

# NOTE(review): model_path is never used (weights are pulled from the hub below);
# kept for backward compatibility — confirm before removing.
model_path = "./sdxl"
image_dir = "/mnt/petrelfs/zhuchenglin/LLaVA/playground/data/LLaVA-Pretrain/images"
# exist_ok avoids the check-then-create race of the original exists()/makedirs() pair.
os.makedirs(image_dir, exist_ok=True)

# Base SDXL pipeline: runs the first (high-noise) portion of the denoising schedule.
base = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)
base.to(cuda_device)

# Refiner pipeline: shares the base's second text encoder and VAE to save GPU
# memory, and finishes the remaining (low-noise) denoising steps.
refiner = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    text_encoder_2=base.text_encoder_2,
    vae=base.vae,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
)
refiner.to(cuda_device)

with open(json_filename, "r") as f:
    text_data = json.load(f)

n_steps = 60
# Fraction of the schedule handled by the base before handing latents to the refiner.
high_noise_frac = 0.8
# NOTE(review): 20 is unusually high for SDXL (typical range 5-9) and tends to
# oversaturate outputs — confirm this value is intentional.
guidance_scale = 20

for text in text_data:
    # The same prompt drives both passes; hoist it instead of indexing twice.
    prompt = text["conversations"][1]["value"]

    # Base pass: stop at high_noise_frac and emit latents for the refiner.
    latents = base(
        prompt=prompt,
        num_inference_steps=n_steps,
        denoising_end=high_noise_frac,
        output_type="latent",
        guidance_scale=guidance_scale,
    ).images

    # Refiner pass: resume denoising from high_noise_frac on the base latents.
    image = refiner(
        prompt=prompt,
        num_inference_steps=n_steps,
        denoising_start=high_noise_frac,
        image=latents,
        guidance_scale=guidance_scale,
    ).images[0]

    image_path = os.path.join(image_dir, text["image"])
    # dirname() fixes the original split("/")[0] bug: an image path with no "/"
    # used to create a directory named after the file itself, which then made
    # image.save() fail. dirname of a bare filename resolves to image_dir,
    # which already exists.
    os.makedirs(os.path.dirname(image_path), exist_ok=True)
    image.save(image_path)

print("所有图像已成功生成并保存。")