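# PixelFlow Gradio demo (Hugging Face Spaces, ZeroGPU).
# Class-conditional ImageNet generation with the ShoufaChen/PixelFlow-Class2Image checkpoint.
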
import argparse
import os

from PIL import Image
import gradio as gr
import spaces
from imagenet_en_cn import IMAGENET_1K_CLASSES
from omegaconf import OmegaConf
from huggingface_hub import snapshot_download

import torch
# from transformers import T5EncoderModel, AutoTokenizer  # only needed for the text-to-image branch below

from pixelflow.scheduling_pixelflow import PixelFlowScheduler
from pixelflow.pipeline_pixelflow import PixelFlowPipeline
from pixelflow.utils import config as config_utils
from pixelflow.utils.misc import seed_everything
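
# Command-line options; the deploy block below overrides them when hosted on Spaces.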
parser = argparse.ArgumentParser(description='Gradio Demo', add_help=False)
parser.add_argument('--checkpoint', type=str, help='checkpoint folder path')
parser.add_argument('--class_cond', action='store_true', help='use class-conditional generation')
args = parser.parse_args()

# deploy: hard-code class-conditional generation for this Space
args.checkpoint = "pixelflow_c2i"
args.class_cond = True

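# Class-conditional (ImageNet, 256x256) setup: download the checkpoint if missing
# and build the model from its config.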
if args.class_cond:
    output_dir = args.checkpoint
    if not os.path.exists(output_dir):
        snapshot_download(repo_id="ShoufaChen/PixelFlow-Class2Image", local_dir=output_dir)
    config = OmegaConf.load(f"{output_dir}/config.yaml")
    model = config_utils.instantiate_from_config(config.model)
    print(f"Num of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
    ckpt = torch.load(f"{output_dir}/model.pt", map_location="cpu", weights_only=True)
    text_encoder = None
    tokenizer = None
    resolution = 256
    NUM_EXAMPLES = 4
else:
    raise NotImplementedError("Please run locally.")
    # NOTE: the text-to-image (1024px) branch below is unreachable on this Space.
    # It needs the transformers imports commented out above, and `output_dir` and
    # `device` are only defined elsewhere.
    config = OmegaConf.load(f"{output_dir}/config.yaml")
    model = config_utils.instantiate_from_config(config.model).to(device)
    print(f"Num of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
    ckpt = torch.load(f"{output_dir}/model.pt", map_location="cpu", weights_only=True)
    text_encoder = T5EncoderModel.from_pretrained("google/flan-t5-xl").to(device)
    tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl")
    resolution = 1024
    NUM_EXAMPLES = 1
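
# Load the checkpoint weights and switch to inference mode.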
model.load_state_dict(ckpt, strict=True)
model.eval()
print(f"outside space.GPU. {torch.cuda.is_available()=}")
if torch.cuda.is_available():
model = model.cuda()
device = torch.device("cuda")
else:
raise ValueError("No GPU")
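
# Multi-stage flow scheduler and sampling pipeline. num_train_timesteps and
# num_stages come from the checkpoint config; text_encoder/tokenizer are None
# for class-conditional generation.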
scheduler = PixelFlowScheduler(config.scheduler.num_train_timesteps, num_stages=config.scheduler.num_stages, gamma=-1/3)
pipeline = PixelFlowPipeline(
    scheduler,
    model,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    max_token_length=512,
)
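
# The @spaces.GPU decorator runs this function on a ZeroGPU worker, so CUDA is
# only guaranteed to be available inside it.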
@spaces.GPU
def infer(use_ode_dopri5, noise_shift, cfg_scale, class_label, seed, *num_steps_per_stage):
    print(f"inside space.GPU. {torch.cuda.is_available()=}")
    seed_everything(seed)
    with torch.autocast("cuda", dtype=torch.bfloat16), torch.no_grad():
        samples = pipeline(
            prompt=[class_label] * NUM_EXAMPLES,
            height=resolution,
            width=resolution,
            num_inference_steps=list(num_steps_per_stage),
            guidance_scale=cfg_scale,  # classifier-free guidance; set it to 7 for the 384p variant
            device=device,
            shift=noise_shift,
            use_ode_dopri5=use_ode_dopri5,
        )
    samples = (samples * 255).round().astype("uint8")
    samples = [Image.fromarray(sample) for sample in samples]
    return samples
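
# Gradio interface: centered title styling plus the Blocks layout below.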
css = """
h1 {
text-align: center;
display: block;
}
.follow-link {
margin-top: 0.8em;
font-size: 1em;
text-align: center;
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("# PixelFlow: Pixel-Space Generative Models with Flow")
    gr.HTML("""
    <div class="follow-link">
        For text-to-image generation, please follow
        <a href="https://github.com/ShoufaChen/PixelFlow/tree/main?tab=readme-ov-file#demo">text-to-image</a>.
        For more details, refer to our
        <a href="https://arxiv.org/abs/2504.07963">arXiv paper</a> and <a href="https://github.com/ShoufaChen/PixelFlow">GitHub repo</a>.
    </div>
    """)
    with gr.Tabs():
        with gr.TabItem('Generate'):
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        if args.class_cond:
                            user_input = gr.Dropdown(
                                list(IMAGENET_1K_CLASSES.values()),
                                value='daisy [雏菊]',
                                type="index", label='ImageNet-1K Class'
                            )
                        else:
                            # free-form text prompt (text-to-image mode)
                            user_input = gr.Textbox(label='Enter your prompt', show_label=False, max_lines=1, placeholder="Enter your prompt")
                    ode_dopri5 = gr.Checkbox(label="Dopri5 ODE", info="Use the Dopri5 ODE solver")
                    noise_shift = gr.Slider(minimum=1.0, maximum=100.0, step=1, value=1.0, label='Noise Shift')
                    cfg_scale = gr.Slider(minimum=1, maximum=25, step=0.1, value=4.0, label='Classifier-free Guidance Scale')
                    # one step-count slider per denoising stage
                    num_steps_per_stage = []
                    for stage_idx in range(config.scheduler.num_stages):
                        num_steps = gr.Slider(minimum=1, maximum=100, step=1, value=10, label=f'Num Inference Steps (Stage {stage_idx})')
                        num_steps_per_stage.append(num_steps)
                    seed = gr.Slider(minimum=0, maximum=1000, step=1, value=42, label='Seed')
                    button = gr.Button("Generate", variant="primary")
                with gr.Column():
                    output = gr.Gallery(label='Generated Images', height=700)

    button.click(infer, inputs=[ode_dopri5, noise_shift, cfg_scale, user_input, seed, *num_steps_per_stage], outputs=[output])

demo.queue()
demo.launch(share=True, debug=True)
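
# Note: the deploy block near the top overrides --checkpoint/--class_cond, so this
# script always runs the class-conditional demo; remove that block to honor the
# CLI flags when running locally.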