File size: 2,298 Bytes
a501a0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import argparse
from diffusers import DiffusionPipeline
import torch
from PIL import Image
import os
import json

# ---- Command-line interface ----
parser = argparse.ArgumentParser(description="Diffusion Pipeline with Arguments")

parser.add_argument(
    "--json_filename",
    type=str,
    required=True,
    help="Path to the JSON file containing text data",
)
parser.add_argument(
    "--cuda", type=int, required=True, help="CUDA device to use for processing"
)

args = parser.parse_args()
json_filename = args.json_filename
cuda_device = f"cuda:{args.cuda}"  # e.g. --cuda 0 -> "cuda:0"
print(json_filename, cuda_device)

# Local checkpoint directory (NOTE(review): not referenced anywhere below)
# and the root directory generated images are written under.
model_path = "./sdxl"
image_dir = "/mnt/petrelfs/zhuchenglin/LLaVA/playground/data/LLaVA-Pretrain/images"
# exist_ok=True replaces the original exists()-then-makedirs() pair, which
# raced with concurrent runs of this script.
os.makedirs(image_dir, exist_ok=True)

# ---- Two-stage SDXL setup ----
# The base model denoises the first part of the noise schedule and emits
# latents; the refiner completes the remaining steps (ensemble-of-experts
# denoising). Commented-out scheduler tweaks were dead code and are removed.
base = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)
base.to(cuda_device)

# The refiner reuses the base pipeline's second text encoder and VAE so the
# two pipelines share those weights instead of loading duplicate copies.
refiner = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    text_encoder_2=base.text_encoder_2,
    vae=base.vae,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
)
refiner.to(cuda_device)

# ---- Load prompts and generate one image per record ----
# Each record is assumed to be LLaVA-style: the prompt is taken from the
# second conversation turn's "value" field, and "image" is a relative path
# (possibly nested) under image_dir — TODO confirm against the JSON schema.
with open(json_filename, "r", encoding="utf-8") as f:
    text_data = json.load(f)

n_steps = 60
high_noise_frac = 0.8  # fraction of the schedule handled by the base model
guidance_scale = 20

for text in text_data:
    prompt = text["conversations"][1]["value"]

    # Stage 1: base model denoises up to high_noise_frac and returns latents.
    latents = base(
        prompt=prompt,
        num_inference_steps=n_steps,
        denoising_end=high_noise_frac,
        output_type="latent",
        guidance_scale=guidance_scale,
    ).images

    # Stage 2: refiner resumes denoising from the same cut-over point.
    image = refiner(
        prompt=prompt,
        num_inference_steps=n_steps,
        denoising_start=high_noise_frac,
        image=latents,
        guidance_scale=guidance_scale,
    ).images[0]

    image_path = os.path.join(image_dir, text["image"])
    # Create the full parent directory in one call: unlike the original
    # split("/")[0], this handles arbitrarily nested relative paths and
    # avoids the exists()-then-makedirs() race.
    os.makedirs(os.path.dirname(image_path), exist_ok=True)
    image.save(image_path)

print("所有图像已成功生成并保存。")