# diffusion / gen_pic.py
# Uploaded by starriver030515 ("Upload folder using huggingface_hub"),
# commit a501a0c (verified).
import argparse
from diffusers import DiffusionPipeline
import torch
from PIL import Image
import os
import json
# Command-line interface: which caption file to render and which GPU to use.
parser = argparse.ArgumentParser(description="Diffusion Pipeline with Arguments")
parser.add_argument(
    "--json_filename",
    type=str,
    required=True,
    help="Path to the JSON file containing text data",
)
parser.add_argument(
    "--cuda", type=int, required=True, help="CUDA device to use for processing"
)
# Optional override of the output root; the default is the original
# hard-coded location, so existing invocations behave exactly as before.
parser.add_argument(
    "--image_dir",
    type=str,
    default="/mnt/petrelfs/zhuchenglin/LLaVA/playground/data/LLaVA-Pretrain/images",
    help="Root directory that generated images are written under",
)
args = parser.parse_args()
json_filename = args.json_filename
cuda_device = f"cuda:{args.cuda}"  # e.g. --cuda 1 -> "cuda:1"
print(json_filename, cuda_device)

# NOTE(review): not referenced anywhere else in this script; the pipelines are
# loaded from Hugging Face Hub model IDs rather than from this local path.
model_path = "./sdxl"

# Output root; each record's relative "image" path is joined onto it.
image_dir = args.image_dir
# exist_ok=True replaces the racy "check exists, then create" pattern.
os.makedirs(image_dir, exist_ok=True)
# Loading options shared by both SDXL checkpoints: half-precision weights taken
# from the "fp16" variant, read from .safetensors files.
_FP16_LOAD_KWARGS = {
    "torch_dtype": torch.float16,
    "variant": "fp16",
    "use_safetensors": True,
}

# SDXL base pipeline (used for the first part of the denoising schedule).
base = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    **_FP16_LOAD_KWARGS,
)
base.to(cuda_device)

# SDXL refiner pipeline. It is handed the base pipeline's second text encoder
# and VAE objects, so both pipelines share those modules.
refiner = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    vae=base.vae,
    text_encoder_2=base.text_encoder_2,
    **_FP16_LOAD_KWARGS,
)
refiner.to(cuda_device)
# Load the records to render. Judging by how the generation loop indexes them,
# each record appears to carry an "image" relative path and a "conversations"
# list (LLaVA-Pretrain style) — NOTE(review): confirm against the dataset.
# Explicit UTF-8 so captions with non-ASCII text load regardless of locale.
with open(json_filename, "r", encoding="utf-8") as f:
    text_data = json.load(f)

# Sampling hyper-parameters shared by the base and refiner passes.
n_steps = 60  # num_inference_steps given to both base and refiner
high_noise_frac = 0.8  # base stops (denoising_end) and refiner starts (denoising_start) here
guidance_scale = 20  # classifier-free guidance strength for both passes
# Render one image per record. The base pipeline denoises up to
# `high_noise_frac` and returns latents; the refiner resumes from that point
# and produces the final image, which is saved under `image_dir`.
for record in text_data:
    # conversations[1]["value"] is the prompt text (presumably the caption
    # turn in LLaVA-style data — confirm against the dataset).
    prompt = record["conversations"][1]["value"]

    latents = base(
        prompt=prompt,
        num_inference_steps=n_steps,
        denoising_end=high_noise_frac,
        output_type="latent",  # keep latents so the refiner can continue from them
        guidance_scale=guidance_scale,
    ).images
    image = refiner(
        prompt=prompt,
        num_inference_steps=n_steps,
        denoising_start=high_noise_frac,
        image=latents,
        guidance_scale=guidance_scale,
    ).images[0]

    image_path = os.path.join(image_dir, record["image"])
    # Create every missing parent directory of the target file. This covers
    # nested relative paths ("a/b/c.png") and bare filenames ("c.png") alike,
    # without ever creating a directory that shadows the output file name.
    parent_dir = os.path.dirname(image_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    image.save(image_path)

# "All images have been successfully generated and saved."
print("所有图像已成功生成并保存。")