amildravid4292 committed on
Commit
d56b2a9
·
verified ·
1 Parent(s): 5e37626

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -23
app.py CHANGED
@@ -125,7 +125,7 @@ class main():
125
  print("")
126
 
127
 
128
- self.network = None
129
 
130
  young = get_direction(df, "Young", pinverse, 1000, device)
131
  young = debias(young, "Male", df, pinverse, device)
@@ -170,11 +170,7 @@ class main():
170
  self.thick = thick
171
 
172
 
173
- @torch.no_grad()
174
- @spaces.GPU(duration=1000)
175
- def sample_model(self):
176
- self.unet, _, _, _, _ = load_models(self.device)
177
- self.network = sample_weights(self.unet, self.proj, self.mean, self.std, self.v[:, :1000], self.device, factor = 1.00)
178
 
179
 
180
  @torch.no_grad()
@@ -184,8 +180,19 @@ class main():
184
  self.unet.to(device)
185
  self.text_encoder.to(device)
186
  self.vae.to(device)
187
- self.network.to(device)
188
-
 
 
 
 
 
 
 
 
 
 
 
189
 
190
 
191
 
@@ -213,18 +220,9 @@ class main():
213
  for i,t in enumerate(tqdm.tqdm(self.noise_scheduler.timesteps)):
214
  latent_model_input = torch.cat([latents] * 2)
215
  latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep=t)
216
- with self.network:
217
- print(latent_model_input.device)
218
- print(self.unet.device)
219
- print(self.text_encoder.device)
220
- print(self.vae.device)
221
- print(self.network.proj.device)
222
- print(self.network.mean.device)
223
- print(self.network.std.device)
224
- print(self.network.v.device)
225
- print(text_embeddings.device)
226
  noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond= None).sample
227
- print("after inference")
228
  #guidance
229
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
230
  noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
@@ -315,16 +313,22 @@ class main():
315
 
316
  return image
317
 
318
- @spaces.GPU
 
319
  def sample_then_run(self):
320
- self.sample_model()
 
 
 
 
 
321
  prompt = "sks person"
322
  negative_prompt = "low quality, blurry, unfinished, nudity, weapon"
323
  seed = 5
324
  cfg = 3.0
325
  steps = 25
326
- image = self.inference( prompt, negative_prompt, cfg, steps, seed)
327
- torch.save(self.network.proj, "model.pt" )
328
  return image, "model.pt"
329
 
330
 
 
125
  print("")
126
 
127
 
128
+ self.weights = None
129
 
130
  young = get_direction(df, "Young", pinverse, 1000, device)
131
  young = debias(young, "Male", df, pinverse, device)
 
170
  self.thick = thick
171
 
172
 
173
+
 
 
 
 
174
 
175
 
176
  @torch.no_grad()
 
180
  self.unet.to(device)
181
  self.text_encoder.to(device)
182
  self.vae.to(device)
183
+ self.mean.to(device)
184
+ self.std.to(device)
185
+ self.v.to(device)
186
+ self.proj.to(device)
187
+ self.weights.to(device)
188
+
189
+ network = LoRAw2w( self.weights, self.mean, self.std, self.v,
190
+ self.unet,
191
+ rank=1,
192
+ multiplier=1.0,
193
+ alpha=27.0,
194
+ train_method="xattn-strict"
195
+ ).to(device, torch.bfloat16)
196
 
197
 
198
 
 
220
  for i,t in enumerate(tqdm.tqdm(self.noise_scheduler.timesteps)):
221
  latent_model_input = torch.cat([latents] * 2)
222
  latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep=t)
223
+ with network:
 
 
 
 
 
 
 
 
 
224
  noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond= None).sample
225
+
226
  #guidance
227
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
228
  noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
313
 
314
  return image
315
 
316
+ @torch.no_grad()
317
+ @spaces.GPU(duration=1000)
318
  def sample_then_run(self):
319
+ self.unet = UNet2DConditionModel.from_pretrained(
320
+ pretrained_model_name_or_path, subfolder="unet", revision=revision
321
+ )
322
+ self.unet.to(self.device, dtype=torch.bfloat16)
323
+ self.weights = sample_weights(self.unet, self.proj, self.mean, self.std, self.v[:, :1000], self.device, factor = 1.00)
324
+
325
  prompt = "sks person"
326
  negative_prompt = "low quality, blurry, unfinished, nudity, weapon"
327
  seed = 5
328
  cfg = 3.0
329
  steps = 25
330
+ image = self.inference( weights, prompt, negative_prompt, cfg, steps, seed)
331
+ torch.save(self.weights.cpu().detach(), "model.pt" )
332
  return image, "model.pt"
333
 
334