amildravid4292 committed on
Commit
ab2f02b
·
verified ·
1 Parent(s): 9274044

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -23
app.py CHANGED
@@ -124,10 +124,13 @@ def sample_then_run(net):
124
  standev = torch.std(proj, 0)
125
 
126
  # sample
127
- sample = torch.zeros([1, 1000]).to(device)
 
 
128
  for i in range(1000):
129
  sample[0, i] = torch.normal(m[i], standev[i], (1,1))
130
 
 
131
  net = "model_"+str(uuid.uuid4())[:4]+".pt"
132
  torch.save(sample, net)
133
 
@@ -148,7 +151,7 @@ def inference(net, prompt, negative_prompt, guidance_scale, ddim_steps, seed):
148
  v.to(device)
149
 
150
  weights = torch.load(net).to(device)
151
- network = LoRAw2w(weights, mean, std, v[:, :1000],
152
  unet,
153
  rank=1,
154
  multiplier=1.0,
@@ -215,7 +218,7 @@ def edit_inference(net, prompt, negative_prompt, guidance_scale, ddim_steps, see
215
 
216
 
217
  weights = torch.load(net).to(device)
218
- network = LoRAw2w(weights, mean, std, v[:, :1000],
219
  unet,
220
  rank=1,
221
  multiplier=1.0,
@@ -386,32 +389,25 @@ def run_inversion(self, dict, pcs, epochs, weight_decay,lr):
386
 
387
 
388
  @spaces.GPU
389
- def file_upload(self, file):
390
- proj = torch.load(file.name).to(device)
 
 
 
391
 
392
  #pad to 10000 Principal components to keep everything consistent
393
- pcs = proj.shape[1]
394
  padding = torch.zeros((1,10000-pcs)).to(device)
395
- proj = torch.cat((proj, padding), 1)
396
- unet, _, _, _, _ = load_models(device)
397
-
398
 
399
- network = LoRAw2w( proj, mean, std, v[:, :10000],
400
- unet,
401
- rank=1,
402
- multiplier=1.0,
403
- alpha=27.0,
404
- train_method="xattn-strict"
405
- ).to(device, torch.bfloat16)
406
-
407
-
408
- prompt = "sks person"
409
- negative_prompt = "low quality, blurry, unfinished, nudity"
410
  seed = 5
411
  cfg = 3.0
412
  steps = 25
413
- image = inference( prompt, negative_prompt, cfg, steps, seed)
414
- return image
415
 
416
 
417
 
@@ -516,7 +512,7 @@ with gr.Blocks(css="style.css") as demo:
516
  submit.click(
517
  fn=edit_inference, inputs=[net, prompt, negative_prompt, cfg, steps, seed, injection_step, a1, a2, a3, a4], outputs=[net, gallery]
518
  )
519
- # file_input.change(fn=model.file_upload, inputs=file_input, outputs = gallery)
520
 
521
 
522
 
 
124
  standev = torch.std(proj, 0)
125
 
126
  # sample
127
+ sample = torch.zeros([1, 10000]).to(device)
128
+
129
+ #only first 1000 PCs
130
  for i in range(1000):
131
  sample[0, i] = torch.normal(m[i], standev[i], (1,1))
132
 
133
+
134
  net = "model_"+str(uuid.uuid4())[:4]+".pt"
135
  torch.save(sample, net)
136
 
 
151
  v.to(device)
152
 
153
  weights = torch.load(net).to(device)
154
+ network = LoRAw2w(weights, mean, std, v[:, :10000],
155
  unet,
156
  rank=1,
157
  multiplier=1.0,
 
218
 
219
 
220
  weights = torch.load(net).to(device)
221
+ network = LoRAw2w(weights, mean, std, v[:, :10000],
222
  unet,
223
  rank=1,
224
  multiplier=1.0,
 
389
 
390
 
391
  @spaces.GPU
392
+ def file_upload(file):
393
+ device="cuda"
394
+ weights = torch.load(file.name).to(device)
395
+ net = "model_"+str(uuid.uuid4())[:4]+".pt"
396
+ torch.save(weights, net)
397
 
398
  #pad to 10000 Principal components to keep everything consistent
399
+ pcs = net.shape[1]
400
  padding = torch.zeros((1,10000-pcs)).to(device)
401
+ weights = torch.cat((weights, padding), 1)
402
+
 
403
 
404
+ image = prompt = "sks person"
405
+ negative_prompt = "low quality, blurry, unfinished, nudity, weapon"
 
 
 
 
 
 
 
 
 
406
  seed = 5
407
  cfg = 3.0
408
  steps = 25
409
+ image = inference(net, prompt, negative_prompt, cfg, steps, seed)
410
+ return net,net,image
411
 
412
 
413
 
 
512
  submit.click(
513
  fn=edit_inference, inputs=[net, prompt, negative_prompt, cfg, steps, seed, injection_step, a1, a2, a3, a4], outputs=[net, gallery]
514
  )
515
+ file_input.change(fn=file_upload, inputs=[file_input, net], outputs = [net, gallery])
516
 
517
 
518