Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -125,7 +125,7 @@ class main():
|
|
125 |
print("")
|
126 |
|
127 |
|
128 |
-
self.
|
129 |
|
130 |
young = get_direction(df, "Young", pinverse, 1000, device)
|
131 |
young = debias(young, "Male", df, pinverse, device)
|
@@ -170,11 +170,7 @@ class main():
|
|
170 |
self.thick = thick
|
171 |
|
172 |
|
173 |
-
|
174 |
-
@spaces.GPU(duration=1000)
|
175 |
-
def sample_model(self):
|
176 |
-
self.unet, _, _, _, _ = load_models(self.device)
|
177 |
-
self.network = sample_weights(self.unet, self.proj, self.mean, self.std, self.v[:, :1000], self.device, factor = 1.00)
|
178 |
|
179 |
|
180 |
@torch.no_grad()
|
@@ -184,8 +180,19 @@ class main():
|
|
184 |
self.unet.to(device)
|
185 |
self.text_encoder.to(device)
|
186 |
self.vae.to(device)
|
187 |
-
self.
|
188 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
189 |
|
190 |
|
191 |
|
@@ -213,18 +220,9 @@ class main():
|
|
213 |
for i,t in enumerate(tqdm.tqdm(self.noise_scheduler.timesteps)):
|
214 |
latent_model_input = torch.cat([latents] * 2)
|
215 |
latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep=t)
|
216 |
-
with
|
217 |
-
print(latent_model_input.device)
|
218 |
-
print(self.unet.device)
|
219 |
-
print(self.text_encoder.device)
|
220 |
-
print(self.vae.device)
|
221 |
-
print(self.network.proj.device)
|
222 |
-
print(self.network.mean.device)
|
223 |
-
print(self.network.std.device)
|
224 |
-
print(self.network.v.device)
|
225 |
-
print(text_embeddings.device)
|
226 |
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond= None).sample
|
227 |
-
|
228 |
#guidance
|
229 |
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
230 |
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
@@ -315,16 +313,22 @@ class main():
|
|
315 |
|
316 |
return image
|
317 |
|
318 |
-
@
|
|
|
319 |
def sample_then_run(self):
|
320 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
321 |
prompt = "sks person"
|
322 |
negative_prompt = "low quality, blurry, unfinished, nudity, weapon"
|
323 |
seed = 5
|
324 |
cfg = 3.0
|
325 |
steps = 25
|
326 |
-
image = self.inference( prompt, negative_prompt, cfg, steps, seed)
|
327 |
-
torch.save(self.
|
328 |
return image, "model.pt"
|
329 |
|
330 |
|
|
|
125 |
print("")
|
126 |
|
127 |
|
128 |
+
self.weights = None
|
129 |
|
130 |
young = get_direction(df, "Young", pinverse, 1000, device)
|
131 |
young = debias(young, "Male", df, pinverse, device)
|
|
|
170 |
self.thick = thick
|
171 |
|
172 |
|
173 |
+
|
|
|
|
|
|
|
|
|
174 |
|
175 |
|
176 |
@torch.no_grad()
|
|
|
180 |
self.unet.to(device)
|
181 |
self.text_encoder.to(device)
|
182 |
self.vae.to(device)
|
183 |
+
self.mean.to(device)
|
184 |
+
self.std.to(device)
|
185 |
+
self.v.to(device)
|
186 |
+
self.proj.to(device)
|
187 |
+
self.weights.to(device)
|
188 |
+
|
189 |
+
network = LoRAw2w( self.weights, self.mean, self.std, self.v,
|
190 |
+
self.unet,
|
191 |
+
rank=1,
|
192 |
+
multiplier=1.0,
|
193 |
+
alpha=27.0,
|
194 |
+
train_method="xattn-strict"
|
195 |
+
).to(device, torch.bfloat16)
|
196 |
|
197 |
|
198 |
|
|
|
220 |
for i,t in enumerate(tqdm.tqdm(self.noise_scheduler.timesteps)):
|
221 |
latent_model_input = torch.cat([latents] * 2)
|
222 |
latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep=t)
|
223 |
+
with network:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
224 |
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond= None).sample
|
225 |
+
|
226 |
#guidance
|
227 |
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
228 |
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
|
|
313 |
|
314 |
return image
|
315 |
|
316 |
+
@torch.no_grad()
|
317 |
+
@spaces.GPU(duration=1000)
|
318 |
def sample_then_run(self):
|
319 |
+
self.unet = UNet2DConditionModel.from_pretrained(
|
320 |
+
pretrained_model_name_or_path, subfolder="unet", revision=revision
|
321 |
+
)
|
322 |
+
self.unet.to(self.device, dtype=torch.bfloat16)
|
323 |
+
self.weights = sample_weights(self.unet, self.proj, self.mean, self.std, self.v[:, :1000], self.device, factor = 1.00)
|
324 |
+
|
325 |
prompt = "sks person"
|
326 |
negative_prompt = "low quality, blurry, unfinished, nudity, weapon"
|
327 |
seed = 5
|
328 |
cfg = 3.0
|
329 |
steps = 25
|
330 |
+
image = self.inference( weights, prompt, negative_prompt, cfg, steps, seed)
|
331 |
+
torch.save(self.weights.cpu().detach(), "model.pt" )
|
332 |
return image, "model.pt"
|
333 |
|
334 |
|