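# PixelFlow Gradio demo: pixel-space image generation, either text-to-image
# (1024px, conditioned on Flan-T5-XL embeddings) or class-to-image (256px,
# ImageNet-1K classes).
#
# Local usage (script name assumed): python app.py --checkpoint <folder> [--class_cond]
# The hosted deployment hard-codes these flags below.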
import argparse
import os
from PIL import Image
import gradio as gr
import spaces
from imagenet_en_cn import IMAGENET_1K_CLASSES
from omegaconf import OmegaConf
from huggingface_hub import snapshot_download

import torch
from transformers import T5EncoderModel, AutoTokenizer

from pixelflow.scheduling_pixelflow import PixelFlowScheduler
from pixelflow.pipeline_pixelflow import PixelFlowPipeline
from pixelflow.utils import config as config_utils
from pixelflow.utils.misc import seed_everything


parser = argparse.ArgumentParser(description='Gradio Demo', add_help=False)
parser.add_argument('--checkpoint', type=str, help='checkpoint folder path')
parser.add_argument('--class_cond', action='store_true', help='use class conditional generation')
args = parser.parse_args()

# Hosted-demo overrides: ignore the CLI flags and always serve the text-to-image model.
args.checkpoint = "pixelflow_t2i"
args.class_cond = False

output_dir = args.checkpoint

# Select the checkpoint repo and output settings for the chosen mode.
if args.class_cond:
    repo_id = "ShoufaChen/PixelFlow-Class2Image"
    resolution = 256
    NUM_EXAMPLES = 4
else:
    repo_id = "ShoufaChen/PixelFlow-Text2Image"
    resolution = 1024
    NUM_EXAMPLES = 1

# Fetch the checkpoint from the Hub on first run, then build the model from
# its bundled config.
if not os.path.exists(output_dir):
    snapshot_download(repo_id=repo_id, local_dir=output_dir)
config = OmegaConf.load(f"{output_dir}/config.yaml")
model = config_utils.instantiate_from_config(config.model)
print(f"Num of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
ckpt = torch.load(f"{output_dir}/model.pt", map_location="cpu", weights_only=True)

# Text-to-image conditioning uses a Flan-T5-XL text encoder; the
# class-conditional model needs no text tower.
if args.class_cond:
    text_encoder = None
    tokenizer = None
else:
    text_encoder = T5EncoderModel.from_pretrained("google/flan-t5-xl")
    tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl")
model.load_state_dict(ckpt, strict=True)
model.eval()

# Note: on HF Spaces ZeroGPU, the GPU is only attached inside @spaces.GPU
# functions; this startup check runs outside that context.
print(f"outside space.GPU. {torch.cuda.is_available()=}")
if torch.cuda.is_available():
    model = model.cuda()
    text_encoder = text_encoder.cuda() if text_encoder else None
    device = torch.device("cuda")
else:
    raise RuntimeError("A CUDA GPU is required to run this demo.")

# Multi-stage sampling: inference is split into config.scheduler.num_stages
# stages, and the UI exposes one step-count slider per stage.
scheduler = PixelFlowScheduler(
    config.scheduler.num_train_timesteps,
    num_stages=config.scheduler.num_stages,
    gamma=-1/3,
)

pipeline = PixelFlowPipeline(
    scheduler,
    model,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    max_token_length=512,
)

@spaces.GPU(duration=120)
def infer(noise_shift, cfg_scale, class_label, seed, *num_steps_per_stage):
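    """Sample NUM_EXAMPLES images for the given prompt/class label.

    `num_steps_per_stage` receives one slider value per scheduler stage.
    """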
    print(f"inside space.GPU. {torch.cuda.is_available()=}")
    seed_everything(seed)
    with torch.autocast("cuda", dtype=torch.bfloat16), torch.no_grad():
        samples = pipeline(
            prompt=[class_label] * NUM_EXAMPLES,
            height=resolution,
            width=resolution,
            num_inference_steps=list(num_steps_per_stage),
            guidance_scale=cfg_scale,  # classifier-free guidance; set to 7 for the 384p variant
            device=device,
            shift=noise_shift,
            use_ode_dopri5=False,
        )
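    # The pipeline returns float arrays in [0, 1]; convert to uint8 PIL images.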
    samples = (samples * 255).round().astype("uint8")
    samples = [Image.fromarray(sample) for sample in samples]
    return samples


css = """
h1 {
    text-align: center;
    display: block;
}

.follow-link {
    margin-top: 0.8em;
    font-size: 1em;
    text-align: center;
}
"""


# Build the Gradio UI: prompt/class input, sampling controls, and an output gallery.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# PixelFlow: Pixel-Space Generative Models with Flow")
    gr.HTML("""
        <div class="follow-link">
            For online class-to-image generation, please try
            <a href="https://huggingface.co/spaces/ShoufaChen/PixelFlow-Class2Image">class-to-image</a>.
            For more details, refer to our 
                <a href="https://arxiv.org/abs/2504.07963">arXiv paper</a> and <a href="https://github.com/ShoufaChen/PixelFlow">GitHub repo</a>.
        </div>
    """)

    with gr.Tabs():
        with gr.TabItem('Generate'):
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        if args.class_cond:
                            user_input = gr.Dropdown(
                                list(IMAGENET_1K_CLASSES.values()),
                                value='daisy [ι›θŠ]',
                                type="index", label='ImageNet-1K Class'
                            )
                        else:
                            # text input
                            user_input = gr.Textbox(label='Enter your prompt', show_label=False, max_lines=1, placeholder="Enter your prompt")
                    noise_shift = gr.Slider(minimum=1.0, maximum=100.0, step=1, value=1.0, label='Noise Shift')
                    cfg_scale = gr.Slider(minimum=1, maximum=25, step=0.1, value=4.0, label='Classifier-free Guidance Scale')
                    num_steps_per_stage = []
                    for stage_idx in range(config.scheduler.num_stages):
                        num_steps = gr.Slider(minimum=1, maximum=100, step=1, value=10, label=f'Num Inference Steps (Stage {stage_idx})')
                        num_steps_per_stage.append(num_steps)
                    seed = gr.Slider(minimum=0, maximum=1000, step=1, value=42, label='Seed')
                    button = gr.Button("Generate", variant="primary")
                with gr.Column():
                    output = gr.Gallery(label='Generated Images', height=700)
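                    # Wire the Generate button: sampling controls and prompt in, gallery out.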
                    button.click(infer, inputs=[noise_shift, cfg_scale, user_input, seed, *num_steps_per_stage], outputs=[output])
    demo.queue()
    demo.launch(share=True, debug=True)