Spaces:
Sleeping
Sleeping
File size: 4,011 Bytes
2cd4b2a c15417b eb710fe 6b8abea eb710fe 86ffd66 dfca074 9f713c2 c15417b 99e4caa eb710fe 3b7acd7 31081a5 76ee786 ac24ff3 8fdcb49 4e3e10a 99e4caa 18b14c9 99e4caa 2d4497e 99e4caa 76ee786 99e4caa 31081a5 99e4caa 31081a5 3b7acd7 31081a5 99172cd 86ffd66 6b8abea 621123f 6b8abea 621123f 31081a5 6b8abea f15739b 6b8abea 31081a5 6b8abea 31081a5 6b8abea 99e4caa 6b8abea 76ee786 99e4caa 6b8abea 99e4caa 12b3d57 99e4caa 6b8abea 99e4caa 76ee786 99e4caa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
import os
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import gradio as gr
import sys
import tqdm
import uuid
sys.path.append(os.path.abspath(os.path.join("", "..")))
import gc
import warnings
warnings.filterwarnings("ignore")
from PIL import Image
import numpy as np
from editing import get_direction, debias
from sampling import sample_weights
from lora_w2w import LoRAw2w
from transformers import CLIPTextModel
from lora_w2w import LoRAw2w
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel, LMSDiscreteScheduler
from transformers import AutoTokenizer, PretrainedConfig
from diffusers import (
AutoencoderKL,
DDPMScheduler,
DiffusionPipeline,
DPMSolverMultistepScheduler,
UNet2DConditionModel,
PNDMScheduler,
StableDiffusionPipeline
)
from huggingface_hub import snapshot_download
import spaces
# Fetch the Snapchat/w2w Hub repo snapshot; its local directory holds the
# files/ artifacts (mean, std, V, projections, ...) loaded further down.
models_path = snapshot_download(
    repo_id="Snapchat/w2w",
)
@spaces.GPU
def load_models(
    device,
    *,
    pretrained_model_name_or_path="stablediffusionapi/realistic-vision-v51",
    revision=None,
    weight_dtype=torch.bfloat16,
):
    """Load the frozen Stable Diffusion components used by the app.

    Args:
        device: Target device for the UNet, VAE and text encoder (e.g. "cuda").
        pretrained_model_name_or_path: Hub repo id or local path of the base
            Stable Diffusion checkpoint.
        revision: Optional model revision forwarded to every ``from_pretrained``.
        weight_dtype: dtype the UNet, VAE and text encoder are cast to.

    Returns:
        Tuple ``(unet, vae, text_encoder, tokenizer, noise_scheduler)``. The
        three networks have gradients disabled and are moved to ``device`` in
        ``weight_dtype``. ``noise_scheduler`` is the scheduler object that a
        ``StableDiffusionPipeline`` built from the checkpoint ends up with.
    """
    # The pipeline is built only to obtain its scheduler exactly as the
    # pipeline configures it. Only the scheduler is kept, so the pipeline
    # stays on the CPU; moving it to `device` would copy every weight to the
    # GPU just to discard it.
    pipe = StableDiffusionPipeline.from_pretrained(
        pretrained_model_name_or_path,
        revision=revision,
        torch_dtype=torch.float16,
        safety_checker=None,
        requires_safety_checker=False,
    )
    noise_scheduler = pipe.scheduler
    del pipe
    gc.collect()  # release the discarded pipeline's weights promptly

    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path, subfolder="tokenizer", revision=revision
    )
    text_encoder = CLIPTextModel.from_pretrained(
        pretrained_model_name_or_path, subfolder="text_encoder", revision=revision
    )
    vae = AutoencoderKL.from_pretrained(
        pretrained_model_name_or_path, subfolder="vae", revision=revision
    )
    unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path, subfolder="unet", revision=revision
    )

    # Inference only: freeze each network, then move and cast it in place.
    for module in (unet, vae, text_encoder):
        module.requires_grad_(False)
        module.to(device, dtype=weight_dtype)

    return unet, vae, text_encoder, tokenizer, noise_scheduler
device = "cuda"


def _load_bf16_to_device(filename):
    """Load ``<models_path>/files/<filename>`` as a bfloat16 tensor on ``device``.

    The file is materialized on the CPU first (``map_location``), then cast to
    bfloat16, then moved to the module-level ``device``.
    """
    cpu = torch.device('cpu')
    loaded = torch.load(f"{models_path}/files/{filename}", map_location=cpu)
    return loaded.bfloat16().to(device)


# Tensor artifacts: loaded on CPU, cast to bf16, moved to `device`.
mean = _load_bf16_to_device("mean.pt")
std = _load_bf16_to_device("std.pt")
v = _load_bf16_to_device("V.pt")
proj = _load_bf16_to_device("proj_1000pc.pt")
# These two are loaded as-is (default map_location, no dtype cast).
# NOTE(review): torch.load unpickles arbitrary objects. identity_df.pt is
# presumably a pandas DataFrame (judging by its name); on PyTorch versions
# where weights_only defaults to True such a payload may fail to load —
# confirm and pass weights_only explicitly for trusted files.
df = torch.load(f"{models_path}/files/identity_df.pt")
weight_dimensions = torch.load(f"{models_path}/files/weight_dimensions.pt")
pinverse = _load_bf16_to_device("pinverse_1000pc.pt")
unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
@spaces.GPU
def sample_then_run(net=None, factor=1.0):
    """Draw random principal-component coefficients and return a new model name.

    One coefficient is drawn per column of ``proj``. Each comes from a normal
    distribution whose mean and standard deviation are that column's
    statistics (reduced over dim 0); the standard deviation is scaled by
    ``factor``.

    Args:
        net: Current value of the ``gr.State`` wired as this handler's input.
            Gradio passes one positional argument per input component, so
            the parameter must exist; its value is not used here.
        factor: Multiplier on each component's standard deviation. ``1.0``
            samples at the empirical spread of ``proj``.

    Returns:
        str: A fresh name of the form ``"model_XXXX.pt"``, where ``XXXX`` are
        the first four hex characters of a random UUID.
    """
    # Per-component statistics over the rows of ``proj`` (assumes proj is
    # (num_models, num_components), e.g. 1000 components — TODO confirm).
    # Statistics are computed in proj's dtype, then promoted to float32 so
    # the sample has the same dtype the original zero-initialised buffer had.
    m = torch.mean(proj, 0).float()
    standev = torch.std(proj, 0).float()
    # One vectorized draw replaces a Python loop of per-element torch.normal
    # calls; the shape is (1, num_components).
    sample = torch.normal(m, factor * standev).unsqueeze(0)
    # NOTE(review): `sample` is not yet used to build or save any weights, so
    # the returned name does not refer to a file written here — TODO wire
    # `sample` into the weight reconstruction.
    del sample
    net = "model_" + str(uuid.uuid4())[:4] + ".pt"
    return net
# Gradio UI: a single button that samples a new model and stores its name in
# per-session state. Component creation order inside the `with` blocks
# defines the layout.
with gr.Blocks(css="style.css") as demo:
    # Session state holding the model name returned by sample_then_run.
    net = gr.State()
    with gr.Column():
        with gr.Row():
            with gr.Column():
                sample = gr.Button("🎲 Sample New Model")
                # Gradio calls fn with one positional argument per input
                # component (here: the current `net` value), and writes the
                # return value back into `net`.
                sample.click(fn=sample_then_run, inputs = [net], outputs=[net])
# Enable Gradio's request queue, then start the server.
demo.queue().launch()
|