File size: 4,916 Bytes
eb710fe
 
 
 
 
 
 
 
 
 
 
 
 
 
c1cc7f8
 
 
 
 
 
 
eb710fe
 
 
 
 
 
 
 
 
 
 
 
dc32163
 
eb710fe
c1cc7f8
 
 
 
 
0edf57c
41f3674
c1cc7f8
 
 
 
 
 
 
41f3674
eb710fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e642d4
eb710fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
436eef9
eb710fe
436eef9
eb710fe
 
 
 
 
 
 
79f8d64
eb710fe
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import gradio as gr
import sys
import os 
import tqdm
sys.path.append(os.path.abspath(os.path.join("", "..")))
import torch
import gc
import warnings
warnings.filterwarnings("ignore")
from PIL import Image
from utils import load_models, save_model_w2w, save_model_for_diffusers
from sampling import sample_weights
from huggingface_hub import snapshot_download

# Target device for every model and tensor in this demo (a CUDA GPU is assumed).
device = "cuda:0"

# Download (or reuse the locally cached copy of) the weights2weights assets.
models_path = snapshot_download(repo_id="Snapchat/w2w")


def _load_to_device(name):
    """Load tensor asset ``name`` from ``models_path`` as bfloat16 on ``device``.

    ``map_location="cpu"`` makes loading independent of whichever device the
    tensor was saved from; it is moved to ``device`` afterwards.
    """
    return torch.load(f"{models_path}/{name}", map_location="cpu").bfloat16().to(device)


mean = _load_to_device("mean.pt")
std = _load_to_device("std.pt")
v = _load_to_device("V.pt")
proj = _load_to_device("proj_1000pc.pt")

# These two files may hold pickled non-tensor objects (identity_df.pt is
# presumably a pandas DataFrame — TODO confirm). torch>=2.6 defaults to
# weights_only=True, which rejects such objects, so full unpickling is
# requested explicitly to keep the pre-2.6 behavior.
# NOTE(review): weights_only=False executes arbitrary pickle code; this is
# only acceptable because the files come from the Snapchat/w2w repo.
df = torch.load(f"{models_path}/identity_df.pt", weights_only=False)
weight_dimensions = torch.load(f"{models_path}/weight_dimensions.pt", weights_only=False)

unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
# Initial identity weights, sampled using the first 1000 columns of V
# (presumably the principal directions matching proj_1000pc.pt).
network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor=1.00)

def sample_model():
    """Reload the base UNet and sample a fresh set of identity weights.

    Bound to the "Sample New Model" button. ``sample_weights`` takes the
    ``unet`` instance as an argument, so the sampled ``network`` is
    presumably tied to that instance. Both are therefore rebuilt together.
    Reloading only ``unet`` would leave ``inference`` using a network built
    for the discarded model, and no new identity would be sampled.
    """
    global unet, network
    # Release the old model and network first, so their GPU memory can be
    # reclaimed before the replacements are loaded.
    del network
    del unet
    gc.collect()
    torch.cuda.empty_cache()

    unet, _, _, _, _ = load_models(device)
    network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor=1.00)
@torch.no_grad()
def inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed):
    """Generate one 512x512 image using the currently sampled identity weights.

    Runs classifier-free guidance. At every scheduler step, ``unet`` is
    evaluated with ``network`` active on a batch conditioned on
    [negative_prompt, prompt]. The two noise predictions are then blended
    using ``guidance_scale``.

    Args:
        prompt: Text prompt; the UI asks users to include "sks person".
        negative_prompt: Text for the unconditional branch of guidance.
        guidance_scale: Classifier-free guidance weight (the "CFG" slider).
        ddim_steps: Number of scheduler denoising steps (must be >= 1).
        seed: Seed for the initial latent noise.

    Returns:
        A one-element list containing the decoded ``PIL.Image``, the format
        the ``gr.Gallery`` output expects.

    Raises:
        gr.Error: If ``ddim_steps`` is less than 1.
    """
    # Gradio may deliver numeric widget values as floats. Torch seeding and
    # the scheduler's step count both require ints.
    seed = int(seed)
    ddim_steps = int(ddim_steps)
    if ddim_steps < 1:
        raise gr.Error("Inference Steps must be at least 1.")

    generator = torch.Generator(device=device).manual_seed(seed)
    latents = torch.randn(
        (1, unet.config.in_channels, 512 // 8, 512 // 8),
        generator=generator,
        device=device,
    ).bfloat16()

    text_input = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

    # The negative prompt must be truncated too. Otherwise a long negative
    # prompt yields more than max_length tokens, which overflows the text
    # encoder and mismatches the prompt embeddings in torch.cat below.
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [negative_prompt],
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    # Batch order [uncond, cond] must match the chunk(2) split in the loop.
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    noise_scheduler.set_timesteps(ddim_steps)
    latents = latents * noise_scheduler.init_noise_sigma

    for t in tqdm.tqdm(noise_scheduler.timesteps):
        # Duplicate the latents so both guidance branches share one UNet call.
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
        with network:
            noise_pred = unet(
                latent_model_input,
                t,
                encoder_hidden_states=text_embeddings,
                timestep_cond=None,
            ).sample
        # Classifier-free guidance.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

    # Undo the VAE latent scaling factor (0.18215) before decoding.
    latents = 1 / 0.18215 * latents
    image = vae.decode(latents).sample
    # Rescale from the decoder's (assumed) [-1, 1] range to [0, 1], then
    # convert to an HWC uint8 array for PIL.
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
    image = Image.fromarray((image * 255).round().astype("uint8"))

    return [image]

with gr.Blocks() as demo:
    gr.Markdown("# <em>weights2weights</em> Demo")
    with gr.Row():
        with gr.Column():
            # NOTE(review): no event handler is attached to `files` in this
            # file, so uploading a photo currently has no effect; only
            # "Sample New Model" changes the identity.
            files = gr.Files(
                label="Upload a photo of your face to invert, or sample a new model",
                file_types=["image"],
            )
            # Hidden placeholders; nothing in this file makes them visible yet.
            uploaded_files = gr.Gallery(label="Your images", visible=False, columns=5, rows=1, height=125)

            sample = gr.Button("Sample New Model")

            with gr.Column(visible=False) as clear_button:
                remove_and_reupload = gr.ClearButton(value="Remove and upload new ones", components=files, size="sm")
            prompt = gr.Textbox(
                label="Prompt",
                info="Make sure to include 'sks person'",
                placeholder="sks person",
                value="sks person",
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                placeholder="low quality, blurry, unfinished, cartoon",
                value="low quality, blurry, unfinished, cartoon",
            )
            # precision=0 makes Gradio deliver an int, which torch seeding requires.
            seed = gr.Number(value=5, label="Seed", interactive=True, precision=0)
            cfg = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
            # A 0-step run is meaningless and can fail inside the scheduler,
            # so the slider starts at 1.
            steps = gr.Slider(label="Inference Steps", value=50, step=1, minimum=1, maximum=100, interactive=True)

            submit = gr.Button("Submit")

        with gr.Column():
            gallery = gr.Gallery(label="Generated Images")

        sample.click(fn=sample_model)

        submit.click(
            fn=inference,
            inputs=[prompt, negative_prompt, cfg, steps, seed],
            outputs=gallery,
        )

demo.launch(share=True)