import argparse
import os
from PIL import Image
import gradio as gr
import spaces
from imagenet_en_cn import IMAGENET_1K_CLASSES
from omegaconf import OmegaConf
from huggingface_hub import snapshot_download

import torch
# transformers (T5 text encoder) is only needed for text-to-image; imported in that branch below.

from pixelflow.scheduling_pixelflow import PixelFlowScheduler
from pixelflow.pipeline_pixelflow import PixelFlowPipeline
from pixelflow.utils import config as config_utils
from pixelflow.utils.misc import seed_everything


parser = argparse.ArgumentParser(description='Gradio Demo', add_help=False)
parser.add_argument('--checkpoint', type=str, help='checkpoint folder path')
parser.add_argument('--class_cond', action='store_true', help='use class conditional generation')
args = parser.parse_args()
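
# Example local invocation (a sketch, assuming this file is saved as app.py;
# note that the deployment block below overrides these flags):
#   python app.py --checkpoint pixelflow_c2i --class_cond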

# Hosted Space deployment: hard-code the class-conditional checkpoint,
# overriding any command-line arguments.
args.checkpoint = "pixelflow_c2i"
args.class_cond = True


if args.class_cond:
    output_dir = args.checkpoint
    if not os.path.exists(output_dir):
        # Download the class-conditional checkpoint from the Hugging Face Hub on first run.
        snapshot_download(repo_id="ShoufaChen/PixelFlow-Class2Image", local_dir=output_dir)
    config = OmegaConf.load(f"{output_dir}/config.yaml")
    model = config_utils.instantiate_from_config(config.model)
    print(f"Num of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
    ckpt = torch.load(f"{output_dir}/model.pt", map_location="cpu", weights_only=True)
    text_encoder = None  # class-conditional generation uses class indices, not text
    tokenizer = None
    resolution = 256
    NUM_EXAMPLES = 4
else:
    # Text-to-image mode is not supported in this Space; run it locally instead.
    raise NotImplementedError("Please run locally.")
    # Unreachable reference code for the local text-to-image setup:
    from transformers import T5EncoderModel, AutoTokenizer
    output_dir = args.checkpoint
    device = torch.device("cuda")
    config = OmegaConf.load(f"{output_dir}/config.yaml")
    model = config_utils.instantiate_from_config(config.model).to(device)
    print(f"Num of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
    ckpt = torch.load(f"{output_dir}/model.pt", map_location="cpu", weights_only=True)
    text_encoder = T5EncoderModel.from_pretrained("google/flan-t5-xl").to(device)
    tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl")
    resolution = 1024
    NUM_EXAMPLES = 1
model.load_state_dict(ckpt, strict=True)
model.eval()

print(f"outside space.GPU. {torch.cuda.is_available()=}")
if torch.cuda.is_available():
    model = model.cuda()
    device = torch.device("cuda")
else:
    raise ValueError("No GPU")

# Multi-stage flow scheduler: sampling proceeds through config.scheduler.num_stages
# stages, each with its own number of inference steps (see the sliders below).
scheduler = PixelFlowScheduler(config.scheduler.num_train_timesteps, num_stages=config.scheduler.num_stages, gamma=-1/3)

pipeline = PixelFlowPipeline(
    scheduler,
    model,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    max_token_length=512,
)
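
# Minimal sketch of invoking the pipeline directly, without the Gradio UI.
# Assumptions: the class-conditional model takes integer class indices as
# `prompt` (as in infer() below), and `num_inference_steps` needs one entry
# per scheduler stage. The class index 985 is hypothetical; check it against
# IMAGENET_1K_CLASSES before relying on it.
#
#   seed_everything(0)
#   with torch.autocast("cuda", dtype=torch.bfloat16), torch.no_grad():
#       images = pipeline(
#           prompt=[985],  # hypothetical class index
#           height=resolution,
#           width=resolution,
#           num_inference_steps=[10] * config.scheduler.num_stages,
#           guidance_scale=4.0,
#           device=device,
#           shift=1.0,
#       )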

@spaces.GPU
def infer(use_ode_dopri5, noise_shift, cfg_scale, class_label, seed, *num_steps_per_stage):
    print(f"inside spaces.GPU. {torch.cuda.is_available()=}")
    seed_everything(seed)
    with torch.autocast("cuda", dtype=torch.bfloat16), torch.no_grad():
        samples = pipeline(
            prompt=[class_label] * NUM_EXAMPLES,
            height=resolution,
            width=resolution,
            num_inference_steps=list(num_steps_per_stage),  # one step count per scheduler stage
            guidance_scale=cfg_scale,
            device=device,
            shift=noise_shift,
            use_ode_dopri5=use_ode_dopri5,
        )
    # Convert pipeline output (float arrays in [0, 1]) to 8-bit PIL images.
    samples = (samples * 255).round().astype("uint8")
    samples = [Image.fromarray(sample) for sample in samples]
    return samples
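
# Hypothetical direct call mirroring the Gradio wiring below: positional
# arguments follow infer's signature, with one trailing step count per stage.
#   images = infer(False, 1.0, 4.0, 985, 42, *([10] * config.scheduler.num_stages))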


css = """
h1 {
    text-align: center;
    display: block;
}

.follow-link {
    margin-top: 0.8em;
    font-size: 1em;
    text-align: center;
}
"""


with gr.Blocks(css=css) as demo:
    gr.Markdown("# PixelFlow: Pixel-Space Generative Models with Flow")
    gr.HTML("""
        <div class="follow-link">
            For text-to-image generation, please follow 
            <a href="https://github.com/ShoufaChen/PixelFlow/tree/main?tab=readme-ov-file#demo">text-to-image</a>.
            For more details, refer to our 
                <a href="https://arxiv.org/abs/2504.07963">arXiv paper</a> and <a href="https://github.com/ShoufaChen/PixelFlow">GitHub repo</a>.
        </div>
    """)

    with gr.Tabs():
        with gr.TabItem('Generate'):
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        if args.class_cond:
                            user_input = gr.Dropdown(
                                list(IMAGENET_1K_CLASSES.values()),
                                value='daisy [雏菊]',
                                type="index", label='ImageNet-1K Class'
                            )
                        else:
                            # text input
                            user_input = gr.Textbox(label='Enter your prompt', show_label=False, max_lines=1, placeholder="Enter your prompt",)
                    ode_dopri5 = gr.Checkbox(label="Dopri5 ODE", info="Use the adaptive Dopri5 ODE solver")
                    noise_shift = gr.Slider(minimum=1.0, maximum=100.0, step=1, value=1.0, label='Noise Shift')
                    cfg_scale = gr.Slider(minimum=1, maximum=25, step=0.1, value=4.0, label='Classifier-free Guidance Scale')
                    num_steps_per_stage = []
                    for stage_idx in range(config.scheduler.num_stages):
                        num_steps = gr.Slider(minimum=1, maximum=100, step=1, value=10, label=f'Num Inference Steps (Stage {stage_idx})')
                        num_steps_per_stage.append(num_steps)
                    seed = gr.Slider(minimum=0, maximum=1000, step=1, value=42, label='Seed')
                    button = gr.Button("Generate", variant="primary")
                with gr.Column():
                    output = gr.Gallery(label='Generated Images', height=700)
                    button.click(infer, inputs=[ode_dopri5, noise_shift, cfg_scale, user_input, seed, *num_steps_per_stage], outputs=[output])
    demo.queue()
    demo.launch(share=True, debug=True)