Spaces:
Sleeping
Sleeping
import os | |
import gradio as gr | |
import open_clip | |
import torch | |
import taming.models.vqgan | |
import ml_collections | |
import einops | |
import random | |
import pathlib | |
import subprocess | |
import shlex | |
import wget | |
# Model | |
from libs.muse import MUSE | |
import utils | |
import numpy as np | |
from PIL import Image | |
print("cuda available:", torch.cuda.is_available())
print("cuda device count:", torch.cuda.device_count())
# torch.cuda.get_device_name(0) raises when no CUDA device exists, while the
# rest of the app falls back to CPU, so only query the name when CUDA is up.
if torch.cuda.is_available():
    print("cuda device name:", torch.cuda.get_device_name(0))
# print(os.system("nvidia-smi"))
# Diagnostic: os.system streams nvcc's output itself and returns its exit
# status, which is what gets printed here.
print(os.system("nvcc --version"))

# Precomputed embedding of the empty prompt; cfg_nnet repeats it over the
# batch to form the unconditional branch of classifier-free guidance.
empty_context = np.load("assets/contexts/empty_context.npy")

_MUSE_CKPT_DIR = "assets/ckpts/cc3m-285000.ckpt"
_MUSE_CKPT_URL = "https://huggingface.co/nzl-thu/MUSE/resolve/main/assets/ckpts/cc3m-285000.ckpt"
_MUSE_CKPT_FILES = (
    "lr_scheduler.pth",
    "optimizer.pth",
    "nnet.pth",
    "nnet_ema.pth",
    "step.pth",
)


def _download_if_missing(url, dst):
    """Download ``url`` to ``dst`` unless ``dst`` already exists.

    ``wget.download`` does not overwrite an existing target; it stores the new
    copy under a suffixed name such as ``"nnet (1).pth"``. Without this guard
    every restart would re-fetch all checkpoints and accumulate duplicates.
    """
    if os.path.exists(dst):
        print(f"found {dst}, skipping download")
        return
    wget.download(url, dst)


print("downloading cc3m-285000.ckpt")
# The checkpoint "file" is a directory holding the individual state files.
os.makedirs(_MUSE_CKPT_DIR, exist_ok=True)
os.system("ls")  # debug listing of the working directory
for _ckpt_name in _MUSE_CKPT_FILES:
    _download_if_missing(f"{_MUSE_CKPT_URL}/{_ckpt_name}", f"{_MUSE_CKPT_DIR}/{_ckpt_name}")
_download_if_missing(
    "https://huggingface.co/zideliu/vqgan/resolve/main/vqgan_jax_strongaug.ckpt",
    "assets/vqgan_jax_strongaug.ckpt",
)
def set_seed(seed: int):
    """Seed every random number generator used while sampling.

    Seeds Python's ``random`` module, NumPy's global generator, torch's CPU
    generator and all CUDA generators, in that order, with the same value.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)
def d(**kwargs):
    """Wrap keyword arguments in an ``ml_collections.ConfigDict``.

    Shorthand used by ``get_config`` to build each nested config section.
    """
    initial = dict(kwargs)
    return ml_collections.ConfigDict(initial_dictionary=initial)
def get_config():
    """Build the inference configuration for the cc3m-285000 MUSE checkpoint.

    Returns an ``ml_collections.ConfigDict`` with top-level ``seed``,
    ``z_shape``, ``resume_root`` and ``adapter_path`` fields plus the nested
    sections ``autoencoder``, ``optimizer``, ``lr_scheduler``, ``nnet``,
    ``muse`` and ``sample``. ``process`` later overwrites ``seed`` and the
    ``sample.lambdaA`` / ``sample.lambdaB`` / ``sample.sample_steps`` fields
    per request.
    """
    autoencoder_cfg = d(config_file='vq-f16-jax.yaml')
    # Optimizer / scheduler settings are not used for sampling in this file;
    # presumably utils.initialize_train_state needs them — TODO confirm.
    optimizer_cfg = d(
        name='adamw',
        lr=0.0002,
        weight_decay=0.03,
        betas=(0.99, 0.99),
    )
    scheduler_cfg = d(name='customized', warmup_steps=5000)
    nnet_cfg = d(
        name='uvit_t2i_vq',
        img_size=16,
        codebook_size=1024,
        in_chans=4,
        embed_dim=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=False,
        clip_dim=1280,
        num_clip_token=77,
        use_checkpoint=True,
        skip=True,
        d_prj=32,
        is_shared=False,
    )
    # Unpacked into MUSE(**config.muse) at model construction.
    muse_cfg = d(ignore_ind=-1, smoothing=0.1, gen_temp=4.5)
    sample_cfg = d(
        sample_steps=36,
        n_samples=50,
        mini_batch_size=8,
        cfg=True,
        linear_inc_scale=True,
        scale=10.,
        path='',
        lambdaA=2.0,  # Stage I: 2.0; Stage II: still to be determined
        lambdaB=5.0,  # Stage I: 5.0; Stage II: still to be determined
    )

    config = ml_collections.ConfigDict()
    config.seed = 1234
    config.z_shape = (8, 16, 16)
    config.autoencoder = autoencoder_cfg
    config.resume_root = "assets/ckpts/cc3m-285000.ckpt"
    config.adapter_path = None
    config.optimizer = optimizer_cfg
    config.lr_scheduler = scheduler_cfg
    config.nnet = nnet_cfg
    config.muse = muse_cfg
    config.sample = sample_cfg
    return config
def cfg_nnet(x, context, scale=None, lambdaA=None, lambdaB=None):
    """Guided forward pass of ``nnet_ema``, optionally mixing in the style adapter.

    Args:
        x: Current model input; its first dimension is the batch size.
        context: Prompt embeddings for the conditional branch (batch must
            match ``x``).
        scale: Classifier-free guidance weight; only used when ``lambdaA``
            is None.
        lambdaA: Weight of the (adapter - no adapter) difference. None
            disables the adapter entirely.
        lambdaB: Weight of the (conditional - unconditional) difference in
            adapter mode.

    Returns:
        Without adapter: ``cond + scale * (cond - uncond)``.
        With adapter:
        ``cond_w_adapter + lambdaA * (cond_w_adapter - cond) + lambdaB * (cond - uncond)``.
    """
    use_adapter = lambdaA is not None
    _cond = nnet_ema(x, context=context)
    # The adapter pass is a full extra network evaluation, so only run it when
    # its output is actually used (the original always ran it and discarded it).
    _cond_w_adapter = nnet_ema(x, context=context, use_adapter=True) if use_adapter else None
    # Unconditional branch: the precomputed empty-prompt embedding, repeated
    # once per batch element.
    _empty_context = torch.tensor(empty_context, device=device)
    _empty_context = einops.repeat(_empty_context, 'L D -> B L D', B=x.size(0))
    _uncond = nnet_ema(x, context=_empty_context)
    if not use_adapter:
        return _cond + scale * (_cond - _uncond)
    # `scale` is deliberately not touched here, so adapter mode works even
    # when no CFG scale is supplied.
    return _cond_w_adapter + lambdaA * (_cond_w_adapter - _cond) + lambdaB * (_cond - _uncond)
def unprocess(x):
    """Clamp tensor ``x`` into [0, 1] in place and hand back the same object."""
    return x.clamp_(min=0.0, max=1.0)
config = get_config()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Load open_clip and vq model
# OpenCLIP ViT-bigG-14 text encoder; `process` uses it to embed prompts.
prompt_model, _, _ = open_clip.create_model_and_transforms('ViT-bigG-14', 'laion2b_s39b_b160k')
prompt_model = prompt_model.to(device)
prompt_model.eval()
# Freeze it like vq_model / nnet_ema below. With no trainable parameters,
# encode_text on token ids does not record an autograd graph, which saves
# memory for every request.
prompt_model.requires_grad_(False)
tokenizer = open_clip.get_tokenizer('ViT-bigG-14')

vq_model = taming.models.vqgan.get_model('vq-f16-jax.yaml')
vq_model.eval()
vq_model.requires_grad_(False)
vq_model.to(device)

## config
# MUSE sampler over the VQ codebook; hyper-parameters come from config.muse.
muse = MUSE(codebook_size=vq_model.n_embed, device=device, **config.muse)
# Restore the downloaded checkpoint and keep only the EMA network for inference.
train_state = utils.initialize_train_state(config, device)
train_state.resume(ckpt_root=config.resume_root)
nnet_ema = train_state.nnet_ema
nnet_ema.eval()
nnet_ema.requires_grad_(False)
nnet_ema.to(device)
# One row per selectable style:
# (style id, adapter weights path or None, suffix appended to the prompt).
# "None" means plain generation without a style adapter.
_STYLE_TABLE = (
    ("None", None, ""),
    ("0102", "style_adapter/0102.pth", " in watercolor painting style"),
    ("0103", "style_adapter/0103.pth", " in watercolor painting style"),
    ("0106", "style_adapter/0106.pth", " in line drawing style"),
    ("0108", "style_adapter/0108.pth", " in oil painting style"),
    ("0301", "style_adapter/0301.pth", " in 3d rendering style"),
    ("0305", "style_adapter/0305.pth", " in kid crayon drawing style"),
)
style_ref = {style_id: path for style_id, path, _ in _STYLE_TABLE}
style_postfix = {style_id: suffix for style_id, _, suffix in _STYLE_TABLE}
def decode(_batch):
    """Decode a batch of VQ codes with the module-level ``vq_model``.

    Handed to ``muse.generate`` as its decode callback; presumably maps code
    indices back to images — TODO confirm against ``decode_code``.
    """
    decoder = vq_model.decode_code
    return decoder(_batch)
def process(prompt, num_samples, lambdaA, lambdaB, style, seed, sample_steps, image=None):
    """Generate images for ``prompt`` in the selected StyleDrop style.

    Gradio callback shared by the Run button and the examples table.

    Args:
        prompt: Text prompt; the style's suffix from ``style_postfix`` is appended.
        num_samples: Number of images to generate.
        lambdaA: Adapter guidance weight; reset to None for style "None".
        lambdaB: Text guidance weight used with ``lambdaA``; reset to None
            for style "None".
        style: Key into ``style_ref`` / ``style_postfix``.
        seed: RNG seed; -1 draws a random seed from [0, 65535].
        sample_steps: Number of MUSE sampling steps.
        image: Unused here; accepted because the UI wires the reference-image
            component in as an input.

    Returns:
        A list of ``num_samples`` uint8 numpy arrays. Axis 1 of the generated
        batch is moved last, presumably turning NCHW into HWC — TODO confirm.

    Note:
        Mutates shared module-level state (``config`` and the weights of
        ``nnet_ema.adapter``), so concurrent calls would interfere.
    """
    # Slider values may arrive as floats; repeat(), range() and the seeding
    # functions need ints.
    num_samples = int(num_samples)
    seed = int(seed)
    sample_steps = int(sample_steps)

    config.sample.lambdaA = lambdaA
    config.sample.lambdaB = lambdaB
    config.sample.sample_steps = sample_steps
    print(style)
    adapter_path = style_ref[style]
    adapter_postfix = style_postfix[style]
    print(f"load adapter path: {adapter_path}")
    if adapter_path is not None:
        # Load onto CPU first: load_state_dict copies into the adapter's own
        # parameters, and this also works when the file was saved from a GPU
        # but the app runs on CPU.
        nnet_ema.adapter.load_state_dict(torch.load(adapter_path, map_location="cpu"))
    else:
        # No adapter for this style: with lambdaA None, cfg_nnet uses plain
        # classifier-free guidance (presumably muse.generate forwards these
        # config values to cfg_nnet — TODO confirm).
        config.sample.lambdaA = None
        config.sample.lambdaB = None
    print("load adapter Done!")

    # Inference only: no autograd graph through the text encoder or sampler.
    with torch.no_grad():
        # Encode prompt
        text_tokens = tokenizer(prompt + adapter_postfix).to(device)
        text_embedding = prompt_model.encode_text(text_tokens)
        # Original note says B 77 1280; the exact shape depends on encode_text.
        text_embedding = text_embedding.repeat(num_samples, 1, 1)
        print(text_embedding.shape)
        print(f"lambdaA: {lambdaA}, lambdaB: {lambdaB}, sample_steps: {sample_steps}")
        if seed == -1:
            seed = random.randint(0, 65535)
        config.seed = seed
        print(f"seed: {seed}")
        set_seed(config.seed)
        res = muse.generate(config, num_samples, cfg_nnet, decode, is_eval=True, context=text_embedding)
        print(res.shape)
        # Map [0, 1] floats to rounded uint8 and move axis 1 last.
        res = (res * 255 + 0.5).clamp_(0, 255).permute(0, 2, 3, 1).to('cpu', torch.uint8).numpy()
    return list(res[:num_samples])
# Gradio UI: prompt/settings column on the left, output gallery on the right,
# example prompts below. Component creation order determines the layout.
block = gr.Blocks()
with block:
    with gr.Row():
        gr.Markdown("## StyleDrop based on Muse (Inference Only) ")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt")
            # A Button's caption is its `value`; `label=` does not set it.
            run_button = gr.Button(value="Run")
            num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
            # -1 asks `process` to draw a random seed.
            seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=1234)
            style = gr.Radio(choices=["0102", "0103", "0106", "0108", "0305", "None"], type="value", value="None", label="Style")
            with gr.Accordion("Advanced options", open=False):
                lambdaA = gr.Slider(label="lambdaA", minimum=0.0, maximum=5.0, value=2.0, step=0.01)
                lambdaB = gr.Slider(label="lambdaB", minimum=0.0, maximum=10.0, value=5.0, step=0.01)
                sample_steps = gr.Slider(label="Sample steps", minimum=1, maximum=50, value=36, step=1)
            # Reference style image; shown by the examples, ignored by `process`.
            image = gr.Image(value=None)
        with gr.Column():
            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(columns=2, height='auto')

    # Inputs in the order of `process`'s parameters; shared by the examples
    # and the Run button so the two can never drift apart.
    ips = [prompt, num_samples, lambdaA, lambdaB, style, seed, sample_steps, image]

    with gr.Row():
        # Each row: prompt, num_samples, lambdaA, lambdaB, style, seed,
        # sample_steps, reference image path.
        examples = [
            [
                "A banana on the table",
                1, 2.0, 5.0, "0103", 1234, 36,
                "data/image_01_03.jpg",
            ],
            [
                "A cow",
                1, 2.0, 5.0, "0102", 1234, 36,
                "data/image_01_02.jpg",
            ],
            [
                "A portrait of tabby cat",
                1, 2.0, 5.0, "0106", 1234, 36,
                "data/image_01_06.jpg",
            ],
            [
                "A church in the field",
                1, 2.0, 5.0, "0108", 1234, 36,
                "data/image_01_08.jpg",
            ],
            [
                "A Christmas tree",
                1, 2.0, 5.0, "0305", 1234, 36,
                "data/image_03_05.jpg",
            ],
        ]
        gr.Examples(
            examples=examples,
            fn=process,
            inputs=ips,
            outputs=result_gallery,
            # Pre-compute example outputs only when running on Spaces.
            cache_examples=os.getenv('SYSTEM') == 'spaces',
        )
    run_button.click(
        fn=process,
        inputs=ips,
        outputs=[result_gallery],
    )

block.queue().launch(share=False)