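"""Gradio demo for sampling new identity models in weights2weights (w2w) space.

Downloads the released w2w statistics (weight-space mean/std, a
1000-principal-component basis, and projection files) from the Snapchat/w2w
Hub repo, loads a frozen Realistic Vision v5.1 Stable Diffusion backbone,
and exposes a one-button UI that samples a new identity model.
"""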
import os
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import gradio as gr
import sys
import tqdm
import uuid
sys.path.append(os.path.abspath(os.path.join("", "..")))
import gc
import warnings
warnings.filterwarnings("ignore")
from PIL import Image
import numpy as np
from editing import get_direction, debias
from sampling import sample_weights
from lora_w2w import LoRAw2w
from transformers import AutoTokenizer, CLIPTextModel, PretrainedConfig
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from huggingface_hub import snapshot_download
import spaces

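# Pull the released w2w weight-space statistics and projection files from the Hub.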
models_path = snapshot_download(repo_id="Snapchat/w2w")


@spaces.GPU
def load_models(device):
    pretrained_model_name_or_path = "stablediffusionapi/realistic-vision-v51"

    revision = None
    weight_dtype = torch.bfloat16

    # Instantiate the full pipeline only to grab its scheduler, then free it.
    pipe = StableDiffusionPipeline.from_pretrained(
        pretrained_model_name_or_path,
        torch_dtype=torch.float16,
        safety_checker=None,
        requires_safety_checker=False,
    ).to(device)
    noise_scheduler = pipe.scheduler
    del pipe

    # Load the individual components and freeze them for inference.
    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path, subfolder="tokenizer", revision=revision
    )
    text_encoder = CLIPTextModel.from_pretrained(
        pretrained_model_name_or_path, subfolder="text_encoder", revision=revision
    )
    vae = AutoencoderKL.from_pretrained(
        pretrained_model_name_or_path, subfolder="vae", revision=revision
    )
    unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path, subfolder="unet", revision=revision
    )
    unet.requires_grad_(False)
    text_encoder.requires_grad_(False)
    vae.requires_grad_(False)
    unet.to(device, dtype=weight_dtype)
    vae.to(device, dtype=weight_dtype)
    text_encoder.to(device, dtype=weight_dtype)

    return unet, vae, text_encoder, tokenizer, noise_scheduler


device="cuda"
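# Precomputed w2w files: weight-space mean/std, the principal-component basis V,
# projections of the training identities onto the top 1000 PCs, identity metadata,
# per-layer weight dimensions, and the pseudo-inverse of the projection.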
mean = torch.load(f"{models_path}/files/mean.pt", map_location=torch.device('cpu')).bfloat16().to(device)
std = torch.load(f"{models_path}/files/std.pt", map_location=torch.device('cpu')).bfloat16().to(device)
v = torch.load(f"{models_path}/files/V.pt", map_location=torch.device('cpu')).bfloat16().to(device)
proj = torch.load(f"{models_path}/files/proj_1000pc.pt", map_location=torch.device('cpu')).bfloat16().to(device)
df = torch.load(f"{models_path}/files/identity_df.pt")
weight_dimensions = torch.load(f"{models_path}/files/weight_dimensions.pt")
pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt", map_location=torch.device('cpu')).bfloat16().to(device)

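# Frozen diffusion components (bf16) plus the scheduler recovered from the pipeline.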
unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)


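# Sampling a new identity: fit a per-component Gaussian to the training identities'
# PC coefficients, draw a fresh coefficient vector, and return it as a new model file.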
@spaces.GPU
def sample_then_run():
    # Per-component mean and standard deviation across the training identities.
    m = torch.mean(proj, 0)
    standev = torch.std(proj, 0)

    # Std scale for sampling; 1.0 assumed here (larger values give more varied
    # identities at the cost of fidelity).
    factor = 1.0

    # Draw all 1000 principal-component coefficients in one vectorized call.
    sample = torch.normal(m, factor * standev).unsqueeze(0)

    # Save the sampled coefficients under a short random name so later steps can
    # reload this model (persisting to disk is assumed behavior).
    net = "model_" + str(uuid.uuid4())[:4] + ".pt"
    torch.save(sample, net)

    return net

with gr.Blocks(css="style.css") as demo:
    net = gr.State()
    with gr.Column():
        with gr.Row():
            sample = gr.Button("🎲 Sample New Model")

    # Sampling needs no inputs; store the returned model filename in session state.
    sample.click(fn=sample_then_run, inputs=[], outputs=[net])

demo.queue().launch()