File size: 3,764 Bytes
2cd4b2a
c15417b
 
 
 
eb710fe
 
 
6b8abea
eb710fe
 
 
 
 
86ffd66
dfca074
9f713c2
c15417b
99e4caa
 
 
 
 
 
 
 
 
 
 
 
 
eb710fe
3b7acd7
31081a5
76ee786
ac24ff3
8fdcb49
c2136ba
173dd22
 
c2136ba
 
 
 
 
99e4caa
 
c2136ba
 
 
99e4caa
 
c2136ba
99e4caa
 
c2136ba
 
99e4caa
 
c2136ba
 
 
76ee786
c2136ba
 
 
 
 
99172cd
c2136ba
6b8abea
 
 
 
 
 
 
621123f
c2136ba
621123f
31081a5
6b8abea
de7b6ff
 
6b8abea
 
 
f15739b
6b8abea
 
 
8ead4f0
31081a5
6b8abea
31081a5
 
6b8abea
 
99e4caa
 
6b8abea
76ee786
 
 
 
99e4caa
6b8abea
99e4caa
12b3d57
99e4caa
6b8abea
99e4caa
 
 
76ee786
99e4caa
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import gradio as gr
import sys
import tqdm
import uuid
sys.path.append(os.path.abspath(os.path.join("", "..")))
import gc
import warnings
warnings.filterwarnings("ignore")
from PIL import Image
import numpy as np
from editing import get_direction, debias
from sampling import sample_weights
from lora_w2w import LoRAw2w
from transformers import CLIPTextModel
from lora_w2w import LoRAw2w
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel, LMSDiscreteScheduler
from transformers import AutoTokenizer, PretrainedConfig
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
    UNet2DConditionModel,
    PNDMScheduler, 
    StableDiffusionPipeline
)
from huggingface_hub import snapshot_download
import spaces

# Download (or reuse the locally cached copy of) the Snapchat/w2w Hub repo;
# the tensors under files/ are loaded at the bottom of this block.
models_path = snapshot_download(repo_id="Snapchat/w2w")

device = "cuda"
pretrained_model_name_or_path = "stablediffusionapi/realistic-vision-v51"
revision = None
weight_dtype = torch.bfloat16

# Load scheduler, tokenizer and models.
# Only the pipeline's scheduler is kept. The full pipeline is loaded so the
# scheduler class/config come from the repo itself, but it is NOT moved to the
# GPU: it is discarded immediately, so a device copy would only waste time and
# leave cached GPU memory behind.
pipe = StableDiffusionPipeline.from_pretrained(
    pretrained_model_name_or_path,
    revision=revision,
    torch_dtype=torch.float16,
    safety_checker=None,
    requires_safety_checker=False,
)
noise_scheduler = pipe.scheduler
del pipe
gc.collect()  # reclaim the discarded pipeline promptly even if it has reference cycles

tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path, subfolder="tokenizer", revision=revision
)
text_encoder = CLIPTextModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="text_encoder", revision=revision
)
vae = AutoencoderKL.from_pretrained(
    pretrained_model_name_or_path, subfolder="vae", revision=revision
)
unet = UNet2DConditionModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="unet", revision=revision
)

# All three models are used for inference only: freeze them, then move them to
# the GPU in the shared weight dtype.
unet.requires_grad_(False)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.to(device, dtype=weight_dtype)
vae.to(device, dtype=weight_dtype)
text_encoder.to(device, dtype=weight_dtype)


def _load_bf16(filename):
    """Load ``files/<filename>`` from the w2w snapshot as a bfloat16 tensor on ``device``.

    The file is deserialized onto the CPU first (``map_location``), so loading
    does not depend on which device the tensor was saved from. It is then cast
    and moved to ``device``.
    """
    # NOTE(review): torch.load unpickles the file, so this trusts the
    # Snapchat/w2w Hub repo contents. Consider weights_only=True for the
    # pure-tensor files.
    path = f"{models_path}/files/{filename}"
    return torch.load(path, map_location=torch.device("cpu")).bfloat16().to(device)


mean = _load_bf16("mean.pt")
std = _load_bf16("std.pt")
v = _load_bf16("V.pt")
proj = _load_bf16("proj_1000pc.pt")
# Loaded exactly as saved (no dtype/device conversion). Presumably these hold
# non-tensor data (e.g. a DataFrame and a dimension table) — TODO confirm.
df = torch.load(f"{models_path}/files/identity_df.pt")
weight_dimensions = torch.load(f"{models_path}/files/weight_dimensions.pt")
pinverse = _load_bf16("pinverse_1000pc.pt")




@spaces.GPU
def sample_then_run(net):
    """Sample a random point in the principal-component space and return a new model name.

    Each coefficient is drawn independently from a normal distribution. Its mean
    and standard deviation are the per-column statistics (over dim 0) of the
    module-level ``proj`` tensor.

    Args:
        net: Current value of the Gradio ``net`` state (the previously returned
            model name, or ``None`` on first use). It is only printed. A new name
            is always generated.

    Returns:
        str: A name of the form ``"model_XXXX.pt"``, where ``XXXX`` is the first
        four characters of a random UUID4.
    """
    print(net)

    # Per-component mean / standard deviation across the rows of `proj`.
    m = torch.mean(proj, 0)
    standev = torch.std(proj, 0)

    # Draw all coefficients in one vectorized call. The previous version made
    # 1000 separate torch.normal calls, each converting 0-d device tensors to
    # Python floats (a host sync per iteration). The draw is done in float32, as
    # before, and lands on the same device as `proj`; shape is (1, n_components).
    # NOTE(review): n_components is taken from `proj` itself (assumed to be 1000,
    # per the "proj_1000pc" file name) instead of being hard-coded.
    sample = torch.normal(m.float(), standev.float()).unsqueeze(0)

    # NOTE(review): `sample` is not consumed yet. Presumably it is meant to be
    # mapped back to LoRA weights (e.g. via `v`, `mean`, `std`) and saved under
    # the returned name — TODO confirm and wire up.
    del sample

    return "model_" + str(uuid.uuid4())[:4] + ".pt"

    
# Gradio front end: one button that samples a new model and stores the
# returned model name in per-session state.
with gr.Blocks(css="style.css") as demo:
    net = gr.State()

    with gr.Column():
        with gr.Row():
            with gr.Column():
                sample = gr.Button("🎲 Sample New Model")

    # The previous state is passed in and replaced by the newly returned name.
    sample.click(fn=sample_then_run, inputs=[net], outputs=[net])

# queue() returns the Blocks instance itself, so this is the same as the
# chained `demo.queue().launch()`.
demo.queue()
demo.launch()