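"""Gradio demo for weights2weights (w2w).

Samples a new identity model from the w2w weight subspace and generates
images of that identity with a Stable Diffusion pipeline.
"""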
import gradio as gr
import sys
import os
import tqdm
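# Make the repository root importable so that `utils` and `sampling` resolve.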
sys.path.append(os.path.abspath(os.path.join("", "..")))
import torch
import gc
import warnings
warnings.filterwarnings("ignore")
from PIL import Image
from utils import load_models, save_model_w2w, save_model_for_diffusers
from sampling import sample_weights
from huggingface_hub import snapshot_download
# Shared state used by every request: one CUDA device and one RNG.
device = "cuda:0"
generator = torch.Generator(device=device)
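# Download the pretrained w2w assets (weight-space statistics, principal
# component basis, and the 1000-PC projection) from the Hugging Face Hub.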
models_path = snapshot_download(repo_id="Snapchat/w2w")
mean = torch.load(f"{models_path}/mean.pt").bfloat16().to(device)
std = torch.load(f"{models_path}/std.pt").bfloat16().to(device)
v = torch.load(f"{models_path}/V.pt").bfloat16().to(device)
proj = torch.load(f"{models_path}/proj_1000pc.pt").bfloat16().to(device)
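# Identity metadata and per-layer weight shapes shipped with the release
# (loaded here but not referenced elsewhere in this demo).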
df = torch.load(f"{models_path}/identity_df.pt")
weight_dimensions = torch.load(f"{models_path}/weight_dimensions.pt")
unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
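# Sample one identity: draw weights in the span of the first 1000 principal
# components, wrapped in a network that patches the UNet when entered.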
network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor = 1.00)
def sample_model():
    # Resample a fresh identity: reload a clean UNet, then draw new weights
    # in the w2w subspace (mirrors the initial setup above).
    global unet, network
    del unet
    gc.collect()
    torch.cuda.empty_cache()
    unet, _, _, _, _ = load_models(device)
    network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor=1.00)
def inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed):
    # Reassigning `generator` here would shadow the module-level one, so seed
    # it in place. Gradio may pass the seed as a float, hence the cast.
    generator.manual_seed(int(seed))
    # Start from Gaussian noise in the 64x64 latent grid (512 px / 8x VAE downsample).
    latents = torch.randn(
        (1, unet.in_channels, 512 // 8, 512 // 8),
        generator=generator,
        device=device,
    ).bfloat16()
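    # Encode the prompt and the negative prompt, then stack them so the
    # unconditional and conditional passes run in a single batch.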
text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
max_length = text_input.input_ids.shape[-1]
uncond_input = tokenizer(
[negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt"
)
uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
    # Prepare the sampling schedule; sliders can deliver floats, so cast.
    noise_scheduler.set_timesteps(int(ddim_steps))
    latents = latents * noise_scheduler.init_noise_sigma
    for t in tqdm.tqdm(noise_scheduler.timesteps):
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)

        # Predict noise with the sampled w2w weights patched into the UNet.
        with network:
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond=None).sample

        # Classifier-free guidance: push the prediction toward the prompt.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
    # Undo the SD latent scaling factor and decode back to pixel space.
    latents = latents / 0.18215
    image = vae.decode(latents).sample

    # Map from [-1, 1] to [0, 1] and convert to a PIL image.
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
    image = Image.fromarray((image * 255).round().astype("uint8"))
    return [image]
with gr.Blocks() as demo:
gr.Markdown("# <em>weights2weights</em> Demo")
with gr.Row():
with gr.Column():
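            # Upload widgets for face inversion; no inversion handler is wired
            # up in this file, so only "Sample New Model" is functional.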
files = gr.Files(
label="Upload a photo of your face to invert, or sample a new model",
file_types=["image"]
)
uploaded_files = gr.Gallery(label="Your images", visible=False, columns=5, rows=1, height=125)
sample = gr.Button("Sample New Model")
with gr.Column(visible=False) as clear_button:
remove_and_reupload = gr.ClearButton(value="Remove and upload new ones", components=files, size="sm")
            prompt = gr.Textbox(label="Prompt",
                                info="Make sure to include 'sks person'",
                                placeholder="sks person",
                                value="sks person")
negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, cartoon", value="low quality, blurry, unfinished, cartoon")
seed = gr.Number(value=5, label="Seed", interactive=True)
cfg = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
            steps = gr.Slider(label="Inference Steps", value=50, step=1, minimum=1, maximum=100, interactive=True)
submit = gr.Button("Submit")
with gr.Column():
gallery = gr.Gallery(label="Generated Images")
    sample.click(fn=sample_model)
submit.click(fn=inference,
inputs=[prompt, negative_prompt, cfg, steps, seed],
outputs=gallery)
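# `share=True` is ignored on Hugging Face Spaces but creates a public link
# when the app is run locally.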
demo.launch(share=True)