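# Generate one image with Stable Diffusion XL (base + refiner) for every record in a
# LLaVA-style JSON file and save the results under the LLaVA-Pretrain image directory.
#
# Example invocation (the script and file names here are illustrative):
#   python generate_sdxl_images.py --json_filename prompts.json --cuda 0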
import argparse
import json
import os

import torch
from diffusers import DiffusionPipeline
from PIL import Image

parser = argparse.ArgumentParser(description="Diffusion Pipeline with Arguments")
parser.add_argument(
    "--json_filename",
    type=str,
    required=True,
    help="Path to the JSON file containing text data",
)
parser.add_argument(
    "--cuda", type=int, required=True, help="CUDA device to use for processing"
)

args = parser.parse_args()
json_filename = args.json_filename
cuda_device = f"cuda:{args.cuda}"
print(json_filename, cuda_device)

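# Generated images are written under image_dir, mirroring each record's relative
# "image" path from the JSON file.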
model_path = "./sdxl"
image_dir = "/mnt/petrelfs/zhuchenglin/LLaVA/playground/data/LLaVA-Pretrain/images"
if not os.path.exists(image_dir):
    os.makedirs(image_dir)

# Load the SDXL base model in half precision.
base = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)
base.to(cuda_device)

# Load the SDXL refiner, reusing the base model's second text encoder and VAE
# so those weights are not loaded twice.
refiner = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    text_encoder_2=base.text_encoder_2,
    vae=base.vae,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
)
refiner.to(cuda_device)

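# The JSON file is expected to contain a list of LLaVA-style records, each holding a
# "conversations" list (the prompt is taken from conversations[1]["value"]) and an
# "image" field with the relative path under image_dir where the result is saved.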
with open(json_filename, "r") as f:
    text_data = json.load(f)

n_steps = 60  # length of the denoising schedule shared by base and refiner
high_noise_frac = 0.8  # fraction of the schedule run on the base model
guidance_scale = 20  # classifier-free guidance scale

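# For each record, the base pipeline denoises the first high_noise_frac of the
# schedule and returns latents (output_type="latent"); the refiner then continues
# from those latents and decodes the final image.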
for text in text_data:
    prompt = text["conversations"][1]["value"]

    # Stage 1: base model, latent output.
    image = base(
        prompt=prompt,
        num_inference_steps=n_steps,
        denoising_end=high_noise_frac,
        output_type="latent",
        guidance_scale=guidance_scale,
    ).images

    # Stage 2: refiner finishes denoising from the latents.
    image = refiner(
        prompt=prompt,
        num_inference_steps=n_steps,
        denoising_start=high_noise_frac,
        image=image,
        guidance_scale=guidance_scale,
    ).images[0]

    # Save the image under image_dir, mirroring the record's relative "image" path.
    subdir = text["image"].split("/")[0]
    if not os.path.exists(os.path.join(image_dir, subdir)):
        os.makedirs(os.path.join(image_dir, subdir))
    image_path = os.path.join(image_dir, text["image"])
    image.save(image_path)

print("All images have been generated and saved successfully.")