amildravid4292 committed
Commit 51836fc · verified · 1 Parent(s): 25dcce1

Update app.py

Files changed (1)
  1. app.py +176 -239
app.py CHANGED
@@ -5,6 +5,7 @@ import torchvision.transforms as transforms
 from torch.utils.data import Dataset, DataLoader
 import gradio as gr
 import sys
+import uuid
 import tqdm
 sys.path.append(os.path.abspath(os.path.join("", "..")))
 import gc
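
The only change in this first hunk is the new `import uuid`. Nothing in the hunks below uses it yet; a plausible reading (an assumption, not confirmed by this commit) is that it prepares for per-session filenames, since `sample_then_run` below still saves every user's sampled identity to the shared hard-coded path "model.pt". A hypothetical sketch of that pattern:

import uuid
import torch

def save_weights_uniquely(weights: torch.Tensor) -> str:
    # Hypothetical helper: a unique filename per sampled identity, so
    # concurrent Space users would not overwrite each other's model.pt.
    path = f"model_{uuid.uuid4().hex}.pt"
    torch.save(weights.cpu().detach(), path)
    return path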
@@ -31,162 +32,102 @@ from diffusers import (
 from huggingface_hub import snapshot_download
 import spaces
 
+
 models_path = snapshot_download(repo_id="Snapchat/w2w")
 
 
-@spaces.GPU
-def load_models(device):
-    pretrained_model_name_or_path = "stablediffusionapi/realistic-vision-v51"
-
-    revision = None
-    weight_dtype = torch.bfloat16
-
-    # Load scheduler, tokenizer and models.
-    pipe = StableDiffusionPipeline.from_pretrained("stablediffusionapi/realistic-vision-v51",
+device = "cuda"
+pretrained_model_name_or_path = "stablediffusionapi/realistic-vision-v51"
+revision = None
+weight_dtype = torch.bfloat16
+# Load scheduler, tokenizer and models.
+pipe = StableDiffusionPipeline.from_pretrained("stablediffusionapi/realistic-vision-v51",
     torch_dtype=torch.float16,safety_checker = None,
     requires_safety_checker = False).to(device)
-    noise_scheduler = pipe.scheduler
-    del pipe
-    tokenizer = AutoTokenizer.from_pretrained(
+noise_scheduler = pipe.scheduler
+del pipe
+tokenizer = AutoTokenizer.from_pretrained(
     pretrained_model_name_or_path, subfolder="tokenizer", revision=revision
 )
-    text_encoder = CLIPTextModel.from_pretrained(
+text_encoder = CLIPTextModel.from_pretrained(
     pretrained_model_name_or_path, subfolder="text_encoder", revision=revision
 )
-    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae", revision=revision)
-    unet = UNet2DConditionModel.from_pretrained(
+vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae", revision=revision)
+unet = UNet2DConditionModel.from_pretrained(
     pretrained_model_name_or_path, subfolder="unet", revision=revision
 )
-    unet.requires_grad_(False)
-    unet.to(device, dtype=weight_dtype)
-    vae.requires_grad_(False)
-
-    text_encoder.requires_grad_(False)
-    vae.requires_grad_(False)
-    vae.to(device, dtype=weight_dtype)
-    text_encoder.to(device, dtype=weight_dtype)
-    print("")
+unet.requires_grad_(False)
+unet.to(device, dtype=weight_dtype)
+vae.requires_grad_(False)
+
+text_encoder.requires_grad_(False)
+vae.requires_grad_(False)
+vae.to(device, dtype=weight_dtype)
+text_encoder.to(device, dtype=weight_dtype)
+print("")
 
-    return unet, vae, text_encoder, tokenizer, noise_scheduler
 
-
-class main():
-    def __init__(self):
-        super(main, self).__init__()
-
-        device = "cuda"
-        mean = torch.load(f"{models_path}/files/mean.pt", map_location=torch.device('cpu')).bfloat16().to(device)
-        std = torch.load(f"{models_path}/files/std.pt", map_location=torch.device('cpu')).bfloat16().to(device)
-        v = torch.load(f"{models_path}/files/V.pt", map_location=torch.device('cpu')).bfloat16().to(device)
-        proj = torch.load(f"{models_path}/files/proj_1000pc.pt", map_location=torch.device('cpu')).bfloat16().to(device)
-        df = torch.load(f"{models_path}/files/identity_df.pt")
-        weight_dimensions = torch.load(f"{models_path}/files/weight_dimensions.pt")
-        pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt", map_location=torch.device('cpu')).bfloat16().to(device)
-
-        self.device = device
-        self.mean = mean
-        self.std = std
-        self.v = v
-        self.proj = proj
-        self.df = df
-        self.weight_dimensions = weight_dimensions
-        self.pinverse = pinverse
-
-        pretrained_model_name_or_path = "stablediffusionapi/realistic-vision-v51"
-
-        revision = None
-        rank = 1
-        weight_dtype = torch.bfloat16
-
-        # Load scheduler, tokenizer and models.
-        pipe = StableDiffusionPipeline.from_pretrained("stablediffusionapi/realistic-vision-v51",
-                                                       torch_dtype=torch.float16,safety_checker = None,
-                                                       requires_safety_checker = False).to(device)
-        self.noise_scheduler = pipe.scheduler
-        del pipe
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            pretrained_model_name_or_path, subfolder="tokenizer", revision=revision
-        )
-        self.text_encoder = CLIPTextModel.from_pretrained(
-            pretrained_model_name_or_path, subfolder="text_encoder", revision=revision
-        )
-        self.vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae", revision=revision)
-        self.unet = UNet2DConditionModel.from_pretrained(
-            pretrained_model_name_or_path, subfolder="unet", revision=revision
-        )
-
-        self.unet.requires_grad_(False)
-        self.unet.to(device, dtype=weight_dtype)
-        self.vae.requires_grad_(False)
-
-        self.text_encoder.requires_grad_(False)
-        self.vae.requires_grad_(False)
-        self.vae.to(device, dtype=weight_dtype)
-        self.text_encoder.to(device, dtype=weight_dtype)
-        print("")
-
-        self.weights = None
-
-        young = get_direction(df, "Young", pinverse, 1000, device)
-        young = debias(young, "Male", df, pinverse, device)
-        young = debias(young, "Pointy_Nose", df, pinverse, device)
-        young = debias(young, "Wavy_Hair", df, pinverse, device)
-        young = debias(young, "Chubby", df, pinverse, device)
-        young = debias(young, "No_Beard", df, pinverse, device)
-        young = debias(young, "Mustache", df, pinverse, device)
-        self.young = young
-
-        pointy = get_direction(df, "Pointy_Nose", pinverse, 1000, device)
-        pointy = debias(pointy, "Young", df, pinverse, device)
-        pointy = debias(pointy, "Male", df, pinverse, device)
-        pointy = debias(pointy, "Wavy_Hair", df, pinverse, device)
-        pointy = debias(pointy, "Chubby", df, pinverse, device)
-        pointy = debias(pointy, "Heavy_Makeup", df, pinverse, device)
-        self.pointy = pointy
-
-        wavy = get_direction(df, "Wavy_Hair", pinverse, 1000, device)
-        wavy = debias(wavy, "Young", df, pinverse, device)
-        wavy = debias(wavy, "Male", df, pinverse, device)
-        wavy = debias(wavy, "Pointy_Nose", df, pinverse, device)
-        wavy = debias(wavy, "Chubby", df, pinverse, device)
-        wavy = debias(wavy, "Heavy_Makeup", df, pinverse, device)
-        self.wavy = wavy
-
-        thick = get_direction(df, "Bushy_Eyebrows", pinverse, 1000, device)
-        thick = debias(thick, "Male", df, pinverse, device)
-        thick = debias(thick, "Young", df, pinverse, device)
-        thick = debias(thick, "Pointy_Nose", df, pinverse, device)
-        thick = debias(thick, "Wavy_Hair", df, pinverse, device)
-        thick = debias(thick, "Mustache", df, pinverse, device)
-        thick = debias(thick, "No_Beard", df, pinverse, device)
-        thick = debias(thick, "Sideburns", df, pinverse, device)
-        thick = debias(thick, "Big_Nose", df, pinverse, device)
-        thick = debias(thick, "Big_Lips", df, pinverse, device)
-        thick = debias(thick, "Black_Hair", df, pinverse, device)
-        thick = debias(thick, "Brown_Hair", df, pinverse, device)
-        thick = debias(thick, "Pale_Skin", df, pinverse, device)
-        thick = debias(thick, "Heavy_Makeup", df, pinverse, device)
-        self.thick = thick
-
-    @torch.no_grad()
-    @spaces.GPU(duration=120)
-    def inference(self, prompt, negative_prompt, guidance_scale, ddim_steps, seed):
-        device = self.device
-        self.unet.to(device)
-        self.text_encoder.to(device)
-        self.vae.to(device)
-        self.mean.to(device)
-        self.std.to(device)
-        self.v.to(device)
-        self.proj.to(device)
-        self.weights.to(device)
+mean = torch.load(f"{models_path}/files/mean.pt", map_location=torch.device('cpu')).bfloat16().to(device)
+std = torch.load(f"{models_path}/files/std.pt", map_location=torch.device('cpu')).bfloat16().to(device)
+v = torch.load(f"{models_path}/files/V.pt", map_location=torch.device('cpu')).bfloat16().to(device)
+proj = torch.load(f"{models_path}/files/proj_1000pc.pt", map_location=torch.device('cpu')).bfloat16().to(device)
+df = torch.load(f"{models_path}/files/identity_df.pt")
+weight_dimensions = torch.load(f"{models_path}/files/weight_dimensions.pt")
+pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt", map_location=torch.device('cpu')).bfloat16().to(device)
+
+young = get_direction(df, "Young", pinverse, 1000, device)
+young = debias(young, "Male", df, pinverse, device)
+young = debias(young, "Pointy_Nose", df, pinverse, device)
+young = debias(young, "Wavy_Hair", df, pinverse, device)
+young = debias(young, "Chubby", df, pinverse, device)
+young = debias(young, "No_Beard", df, pinverse, device)
+young = debias(young, "Mustache", df, pinverse, device)
+
+pointy = get_direction(df, "Pointy_Nose", pinverse, 1000, device)
+pointy = debias(pointy, "Young", df, pinverse, device)
+pointy = debias(pointy, "Male", df, pinverse, device)
+pointy = debias(pointy, "Wavy_Hair", df, pinverse, device)
+pointy = debias(pointy, "Chubby", df, pinverse, device)
+pointy = debias(pointy, "Heavy_Makeup", df, pinverse, device)
+
+wavy = get_direction(df, "Wavy_Hair", pinverse, 1000, device)
+wavy = debias(wavy, "Young", df, pinverse, device)
+wavy = debias(wavy, "Male", df, pinverse, device)
+wavy = debias(wavy, "Pointy_Nose", df, pinverse, device)
+wavy = debias(wavy, "Chubby", df, pinverse, device)
+wavy = debias(wavy, "Heavy_Makeup", df, pinverse, device)
+
+thick = get_direction(df, "Bushy_Eyebrows", pinverse, 1000, device)
+thick = debias(thick, "Male", df, pinverse, device)
+thick = debias(thick, "Young", df, pinverse, device)
+thick = debias(thick, "Pointy_Nose", df, pinverse, device)
+thick = debias(thick, "Wavy_Hair", df, pinverse, device)
+thick = debias(thick, "Mustache", df, pinverse, device)
+thick = debias(thick, "No_Beard", df, pinverse, device)
+thick = debias(thick, "Sideburns", df, pinverse, device)
+thick = debias(thick, "Big_Nose", df, pinverse, device)
+thick = debias(thick, "Big_Lips", df, pinverse, device)
+thick = debias(thick, "Black_Hair", df, pinverse, device)
+thick = debias(thick, "Brown_Hair", df, pinverse, device)
+thick = debias(thick, "Pale_Skin", df, pinverse, device)
+thick = debias(thick, "Heavy_Makeup", df, pinverse, device)
+
+@torch.no_grad()
+@spaces.GPU(duration=120)
+def inference(self, prompt, negative_prompt, guidance_scale, ddim_steps, seed):
+    device = self.device
+    self.unet.to(device)
+    self.text_encoder.to(device)
+    self.vae.to(device)
+    self.mean.to(device)
+    self.std.to(device)
+    self.v.to(device)
+    self.proj.to(device)
+    self.weights.to(device)
 
-        network = LoRAw2w( self.weights.bfloat16(), self.mean.bfloat16(), self.std.bfloat16(), self.v[:, :1000].bfloat16(),
+    network = LoRAw2w( self.weights.bfloat16(), self.mean.bfloat16(), self.std.bfloat16(), self.v[:, :1000].bfloat16(),
         self.unet,
         rank=1,
         multiplier=1.0,
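
This hunk moves all model loading out of `load_models()` and `main.__init__` to module scope, so the pipeline components (`unet`, `vae`, `text_encoder`, `tokenizer`, `noise_scheduler`) and the weights2weights statistics (`mean`, `std`, `v`, `proj`, `pinverse`) are built once at import time instead of per request. Note that `inference` (and the functions in the later hunks) still takes `self` even though `class main():` is removed here; whether a class wrapper survives elsewhere in the file is not visible in this diff. The direction blocks all repeat one pattern: project an attribute onto the first 1000 principal components with `get_direction`, then strip out correlated attributes with repeated `debias` calls. A minimal sketch of how that repetition could be factored, reusing exactly the signatures shown above (the helper name is hypothetical, not part of the commit):

def build_direction(attr, confounds, df, pinverse, device, n_pcs=1000):
    # Raw attribute direction in principal-component space...
    direction = get_direction(df, attr, pinverse, n_pcs, device)
    # ...minus each correlated attribute, so the edit stays disentangled.
    for confound in confounds:
        direction = debias(direction, confound, df, pinverse, device)
    return direction

# Equivalent to the "young" block in the hunk above:
young = build_direction(
    "Young",
    ["Male", "Pointy_Nose", "Wavy_Hair", "Chubby", "No_Beard", "Mustache"],
    df, pinverse, device,
)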
@@ -196,68 +137,67 @@ class main():
 
 
 
-        generator = torch.Generator(device=device).manual_seed(seed)
-        latents = torch.randn(
+    generator = torch.Generator(device=device).manual_seed(seed)
+    latents = torch.randn(
         (1, self.unet.in_channels, 512 // 8, 512 // 8),
         generator = generator,
         device = self.device
         ).bfloat16()
 
 
-        text_input = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
+    text_input = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
 
-        text_embeddings = self.text_encoder(text_input.input_ids.to(device))[0]
+    text_embeddings = self.text_encoder(text_input.input_ids.to(device))[0]
 
-        max_length = text_input.input_ids.shape[-1]
-        uncond_input = self.tokenizer(
+    max_length = text_input.input_ids.shape[-1]
+    uncond_input = self.tokenizer(
         [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt"
         )
-        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
-        text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).bfloat16()
-        self.noise_scheduler.set_timesteps(ddim_steps)
-        latents = latents * self.noise_scheduler.init_noise_sigma
+    uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
+    text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).bfloat16()
+    self.noise_scheduler.set_timesteps(ddim_steps)
+    latents = latents * self.noise_scheduler.init_noise_sigma
 
-        for i,t in enumerate(tqdm.tqdm(self.noise_scheduler.timesteps)):
-            latent_model_input = torch.cat([latents] * 2)
-            latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep=t)
+    for i,t in enumerate(tqdm.tqdm(self.noise_scheduler.timesteps)):
+        latent_model_input = torch.cat([latents] * 2)
+        latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep=t)
 
-            with network:
-                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond=None).sample
+        with network:
+            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond=None).sample
 
-            #guidance
-            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-            latents = self.noise_scheduler.step(noise_pred, t, latents).prev_sample
+        #guidance
+        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+        latents = self.noise_scheduler.step(noise_pred, t, latents).prev_sample
 
-        latents = 1 / 0.18215 * latents
-        image = self.vae.decode(latents.float()).sample
-        image = (image / 2 + 0.5).clamp(0, 1)
-        image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
+    latents = 1 / 0.18215 * latents
+    image = self.vae.decode(latents.float()).sample
+    image = (image / 2 + 0.5).clamp(0, 1)
+    image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
 
-        image = Image.fromarray((image * 255).round().astype("uint8"))
+    image = Image.fromarray((image * 255).round().astype("uint8"))
 
-        return image
+    return image
 
 
-    @torch.no_grad()
-    @spaces.GPU(duration=120)
-    def edit_inference(self, prompt, negative_prompt, guidance_scale, ddim_steps, seed, start_noise, a1, a2, a3, a4):
-        print("start")
-        device = self.device
-        self.unet.to(device)
-        self.text_encoder.to(device)
-        self.vae.to(device)
-        self.mean.to(device)
-        self.std.to(device)
-        self.v.to(device)
-        self.proj.to(device)
-        self.weights = torch.load("model.pt").to(device)
-        self.young.to(device)
-        self.pointy.to(device)
-        self.wavy.to(device)
-        self.thick.to(device)
+@torch.no_grad()
+@spaces.GPU(duration=120)
+def edit_inference(self, prompt, negative_prompt, guidance_scale, ddim_steps, seed, start_noise, a1, a2, a3, a4):
+    device = self.device
+    self.unet.to(device)
+    self.text_encoder.to(device)
+    self.vae.to(device)
+    self.mean.to(device)
+    self.std.to(device)
+    self.v.to(device)
+    self.proj.to(device)
+    self.weights = torch.load("model.pt").to(device)
+    self.young.to(device)
+    self.pointy.to(device)
+    self.wavy.to(device)
+    self.thick.to(device)
 
-        network = LoRAw2w( self.weights.bfloat16(), self.mean.bfloat16(), self.std.bfloat16(), self.v[:, :1000].bfloat16(),
+    network = LoRAw2w( self.weights.bfloat16(), self.mean.bfloat16(), self.std.bfloat16(), self.v[:, :1000].bfloat16(),
         self.unet,
         rank=1,
        multiplier=1.0,
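
The denoising loop in `inference` (re-indented by this commit but otherwise unchanged, apart from dropping the stray `print("start")` in `edit_inference`) is a standard classifier-free guidance loop: the latents are doubled, the UNet predicts noise for the unconditional and text-conditioned halves in one pass, and the two predictions are recombined before the scheduler step. The recombination is plain tensor arithmetic and can be checked in isolation; a self-contained sketch with toy tensors (no model, shapes assumed for illustration only):

import torch

guidance_scale = 3.0

# Stand-in for the UNet output on the doubled [uncond, text] batch.
noise_pred = torch.randn(2, 4, 64, 64)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

# Guidance moves the prediction away from the unconditional estimate,
# scaled by guidance_scale, exactly as in the loop above.
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
assert guided.shape == (1, 4, 64, 64)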
@@ -266,90 +206,87 @@ class main():
         ).to(device, torch.bfloat16)
 
 
-        original_weights = self.weights.clone()
+    original_weights = self.weights.clone()
 
-        #pad to same number of PCs
-        pcs_original = original_weights.shape[1]
-        pcs_edits = self.young.shape[1]
-        padding = torch.zeros((1,pcs_original-pcs_edits)).to(device)
-        young_pad = torch.cat((self.young, padding), 1)
-        pointy_pad = torch.cat((self.pointy, padding), 1)
-        wavy_pad = torch.cat((self.wavy, padding), 1)
-        thick_pad = torch.cat((self.thick, padding), 1)
+    #pad to same number of PCs
+    pcs_original = original_weights.shape[1]
+    pcs_edits = self.young.shape[1]
+    padding = torch.zeros((1,pcs_original-pcs_edits)).to(device)
+    young_pad = torch.cat((self.young, padding), 1)
+    pointy_pad = torch.cat((self.pointy, padding), 1)
+    wavy_pad = torch.cat((self.wavy, padding), 1)
+    thick_pad = torch.cat((self.thick, padding), 1)
 
 
-        edited_weights = original_weights+a1*1e6*young_pad+a2*1e6*pointy_pad+a3*1e6*wavy_pad+a4*2e6*thick_pad
+    edited_weights = original_weights+a1*1e6*young_pad+a2*1e6*pointy_pad+a3*1e6*wavy_pad+a4*2e6*thick_pad
 
-        generator = torch.Generator(device=device).manual_seed(seed)
-        latents = torch.randn(
+    generator = torch.Generator(device=device).manual_seed(seed)
+    latents = torch.randn(
         (1, self.unet.in_channels, 512 // 8, 512 // 8),
         generator = generator,
         device = self.device
         ).bfloat16()
 
 
-        text_input = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
+    text_input = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
 
-        text_embeddings = self.text_encoder(text_input.input_ids.to(device))[0]
+    text_embeddings = self.text_encoder(text_input.input_ids.to(device))[0]
 
-        max_length = text_input.input_ids.shape[-1]
-        uncond_input = self.tokenizer(
+    max_length = text_input.input_ids.shape[-1]
+    uncond_input = self.tokenizer(
         [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt"
         )
-        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
-        text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).bfloat16()
-        self.noise_scheduler.set_timesteps(ddim_steps)
-        latents = latents * self.noise_scheduler.init_noise_sigma
+    uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
+    text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).bfloat16()
+    self.noise_scheduler.set_timesteps(ddim_steps)
+    latents = latents * self.noise_scheduler.init_noise_sigma
 
 
-        for i,t in enumerate(tqdm.tqdm(self.noise_scheduler.timesteps)):
-            latent_model_input = torch.cat([latents] * 2)
-            latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep=t)
+    for i,t in enumerate(tqdm.tqdm(self.noise_scheduler.timesteps)):
+        latent_model_input = torch.cat([latents] * 2)
+        latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep=t)
 
-            if t>start_noise:
-                pass
-            elif t<=start_noise:
-                network.proj = torch.nn.Parameter(edited_weights)
-                network.reset()
+        if t>start_noise:
+            pass
+        elif t<=start_noise:
+            network.proj = torch.nn.Parameter(edited_weights)
+            network.reset()
 
-            with network:
-                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond=None).sample
+        with network:
+            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond=None).sample
 
 
-            #guidance
-            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-            latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
+        #guidance
+        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
 
-        latents = 1 / 0.18215 * latents
-        image = self.vae.decode(latents.float()).sample
-        image = (image / 2 + 0.5).clamp(0, 1)
-
-        image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
-
-        image = Image.fromarray((image * 255).round().astype("uint8"))
+    latents = 1 / 0.18215 * latents
+    image = self.vae.decode(latents.float()).sample
+    image = (image / 2 + 0.5).clamp(0, 1)
+    image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
+    image = Image.fromarray((image * 255).round().astype("uint8"))
 
-
-        return image
+    return image
 
-    @torch.no_grad()
-    @spaces.GPU(duration=120)
-    def sample_then_run(self):
-        self.unet = UNet2DConditionModel.from_pretrained(
+@torch.no_grad()
+@spaces.GPU(duration=120)
+def sample_then_run(self):
+    self.unet = UNet2DConditionModel.from_pretrained(
         "stablediffusionapi/realistic-vision-v51" , subfolder="unet", revision=None
         )
-        self.unet.to(self.device, dtype=torch.bfloat16)
-        self.weights = sample_weights(self.unet, self.proj, self.mean, self.std, self.v[:, :1000], self.device, factor = 1.00)
-
-        prompt = "sks person"
-        negative_prompt = "low quality, blurry, unfinished, nudity, weapon"
-        seed = 5
-        cfg = 3.0
-        steps = 25
-        image = self.inference(prompt, negative_prompt, cfg, steps, seed)
-        torch.save(self.weights.cpu().detach(), "model.pt" )
-        return image, "model.pt"
+    self.unet.to(self.device, dtype=torch.bfloat16)
+    self.weights = sample_weights(self.unet, self.proj, self.mean, self.std, self.v[:, :1000], self.device, factor = 1.00)
+
+    prompt = "sks person"
+    negative_prompt = "low quality, blurry, unfinished, nudity, weapon"
+    seed = 5
+    cfg = 3.0
+    steps = 25
+    image = self.inference(prompt, negative_prompt, cfg, steps, seed)
+    torch.save(self.weights.cpu().detach(), "model.pt" )
+    return image, "model.pt"
 
 
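`edit_inference` applies the attribute edit as a purely linear move in principal-component space: each 1000-PC direction is zero-padded to the width of the sampled weight vector, scaled by its slider value (`a1`-`a4`, times 1e6, or 2e6 for the eyebrow direction), and swapped into the LoRA projection once the timestep drops below `start_noise`. The padding-and-add arithmetic in isolation (dimensions assumed for illustration; the real sizes come from the files under `{models_path}/files`):

import torch

pcs_original, pcs_edits = 10000, 1000   # assumed sizes, for shape checking only
original_weights = torch.randn(1, pcs_original)
young = torch.randn(1, pcs_edits)       # stands in for a debiased direction

# Zero-pad the short direction so it lines up with the full coefficient row.
padding = torch.zeros((1, pcs_original - pcs_edits))
young_pad = torch.cat((young, padding), 1)

a1 = 0.5                                # slider value from the Gradio UI
edited_weights = original_weights + a1 * 1e6 * young_pad
assert edited_weights.shape == original_weights.shape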
 