amildravid4292 committed on
Commit ad7da82 · verified · 1 Parent(s): df9e08f

Update app.py

Files changed (1): app.py +79 -102
app.py CHANGED
@@ -42,53 +42,85 @@ pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt", map_location=to
 
 unet.value, vae.value, text_encoder.value, tokenizer.value, noise_scheduler.value = load_models(device.value)
 
+
+young = gr.State(get_direction(df, "Young", pinverse, 1000, device.value))
+young.value = debias(young.value, "Male", df, pinverse, device.value)
+young.value = debias(young.value, "Pointy_Nose", df, pinverse, device.value)
+young.value = debias(young.value, "Wavy_Hair", df, pinverse, device.value)
+young.value = debias(young.value, "Chubby", df, pinverse, device.value)
+young.value = debias(young.value, "No_Beard", df, pinverse, device.value)
+young.value = debias(young.value, "Mustache", df, pinverse, device.value)
+
+pointy = gr.State(get_direction(df, "Pointy_Nose", pinverse, 1000, device.value))
+pointy.value = debias(pointy.value, "Young", df, pinverse, device.value)
+pointy.value = debias(pointy.value, "Male", df, pinverse, device.value)
+pointy.value = debias(pointy.value, "Wavy_Hair", df, pinverse, device.value)
+pointy.value = debias(pointy.value, "Chubby", df, pinverse, device.value)
+pointy.value = debias(pointy.value, "Heavy_Makeup", df, pinverse, device.value)
+
+wavy = gr.State(get_direction(df, "Wavy_Hair", pinverse, 1000, device.value))
+wavy.value = debias(wavy.value, "Young", df, pinverse, device.value)
+wavy.value = debias(wavy.value, "Male", df, pinverse, device.value)
+wavy.value = debias(wavy.value, "Pointy_Nose", df, pinverse, device.value)
+wavy.value = debias(wavy.value, "Chubby", df, pinverse, device.value)
+wavy.value = debias(wavy.value, "Heavy_Makeup", df, pinverse, device.value)
+
+thick = gr.State(get_direction(df, "Bushy_Eyebrows", pinverse, 1000, device.value))
+thick.value = debias(thick.value, "Male", df, pinverse, device.value)
+thick.value = debias(thick.value, "Young", df, pinverse, device.value)
+thick.value = debias(thick.value, "Pointy_Nose", df, pinverse, device.value)
+thick.value = debias(thick.value, "Wavy_Hair", df, pinverse, device.value)
+thick.value = debias(thick.value, "Mustache", df, pinverse, device.value)
+thick.value = debias(thick.value, "No_Beard", df, pinverse, device.value)
+thick.value = debias(thick.value, "Sideburns", df, pinverse, device.value)
+thick.value = debias(thick.value, "Big_Nose", df, pinverse, device.value)
+thick.value = debias(thick.value, "Big_Lips", df, pinverse, device.value)
+thick.value = debias(thick.value, "Black_Hair", df, pinverse, device.value)
+thick.value = debias(thick.value, "Brown_Hair", df, pinverse, device.value)
+thick.value = debias(thick.value, "Pale_Skin", df, pinverse, device.value)
+thick.value = debias(thick.value, "Heavy_Makeup", df, pinverse, device.value)
+
 def sample_model():
-    unet.value, _, _, _, _ = load_models(device)
-    network.value = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor = 1.00)
+    unet.value, _, _, _, _ = load_models(device.value)
+    network.value = sample_weights(unet.value, proj.value, mean.value, std.value, v[:, :1000], device.value, factor = 1.00)
 
 @torch.no_grad()
 @spaces.GPU
 def inference( prompt, negative_prompt, guidance_scale, ddim_steps, seed):
-    global device
-    #global generator
-    global unet
-    global vae
-    global text_encoder
-    global tokenizer
-    global noise_scheduler
-    generator = torch.Generator(device=device).manual_seed(seed)
+
+    generator = torch.Generator(device=device.value).manual_seed(seed)
     latents = torch.randn(
         (1, unet.in_channels, 512 // 8, 512 // 8),
         generator = generator,
-        device = device
+        device = device.value
     ).bfloat16()
 
 
-    text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
+    text_input = tokenizer.value(prompt, padding="max_length", max_length=tokenizer.value.model_max_length, truncation=True, return_tensors="pt")
 
-    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
+    text_embeddings = text_encoder.value(text_input.input_ids.to(device.value))[0]
 
     max_length = text_input.input_ids.shape[-1]
-    uncond_input = tokenizer(
+    uncond_input = tokenizer.value(
         [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt"
     )
-    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
+    uncond_embeddings = text_encoder.value(uncond_input.input_ids.to(device.value))[0]
     text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
-    noise_scheduler.set_timesteps(ddim_steps)
-    latents = latents * noise_scheduler.init_noise_sigma
+    noise_scheduler.value.set_timesteps(ddim_steps)
+    latents = latents * noise_scheduler.value.init_noise_sigma
 
-    for i,t in enumerate(tqdm.tqdm(noise_scheduler.timesteps)):
+    for i,t in enumerate(tqdm.tqdm(noise_scheduler.value.timesteps)):
         latent_model_input = torch.cat([latents] * 2)
-        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
-        with network:
-            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond= None).sample
+        latent_model_input = noise_scheduler.value.scale_model_input(latent_model_input, timestep=t)
+        with network.value:
+            noise_pred = unet.value(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond= None).sample
         #guidance
         noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
         noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
         latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
 
     latents = 1 / 0.18215 * latents
-    image = vae.decode(latents).sample
+    image = vae.value.decode(latents).sample
     image = (image / 2 + 0.5).clamp(0, 1)
     image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
 
@@ -100,78 +132,66 @@ def inference( prompt, negative_prompt, guidance_scale, ddim_steps, seed):
 @torch.no_grad()
 @spaces.GPU
 def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, start_noise, a1, a2, a3, a4):
-    start_items()
-    global device
-    #global generator
-    global unet
-    global vae
-    global text_encoder
-    global tokenizer
-    global noise_scheduler
-    global young
-    global pointy
-    global wavy
-    global thick
-
-    original_weights = network.proj.clone()
+
+    original_weights = network.value.proj.clone()
 
     #pad to same number of PCs
     pcs_original = original_weights.shape[1]
-    pcs_edits = young.shape[1]
+    pcs_edits = young.value.shape[1]
     padding = torch.zeros((1,pcs_original-pcs_edits)).to(device)
-    young_pad = torch.cat((young, padding), 1)
-    pointy_pad = torch.cat((pointy, padding), 1)
-    wavy_pad = torch.cat((wavy, padding), 1)
-    thick_pad = torch.cat((thick, padding), 1)
+    young_pad = torch.cat((young.value, padding), 1)
+    pointy_pad = torch.cat((pointy.value, padding), 1)
+    wavy_pad = torch.cat((wavy.value, padding), 1)
+    thick_pad = torch.cat((thick.value, padding), 1)
 
 
     edited_weights = original_weights+a1*1e6*young_pad+a2*1e6*pointy_pad+a3*1e6*wavy_pad+a4*2e6*thick_pad
 
-    generator = torch.Generator(device=device).manual_seed(seed)
+    generator = torch.Generator(device=device.value).manual_seed(seed)
     latents = torch.randn(
         (1, unet.in_channels, 512 // 8, 512 // 8),
         generator = generator,
-        device = device
+        device = device.value
     ).bfloat16()
 
 
-    text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
+    text_input = tokenizer.value(prompt, padding="max_length", max_length=tokenizer.value.model_max_length, truncation=True, return_tensors="pt")
 
-    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
+    text_embeddings = text_encoder.value(text_input.input_ids.to(device.value))[0]
 
     max_length = text_input.input_ids.shape[-1]
-    uncond_input = tokenizer(
+    uncond_input = tokenizer.value(
         [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt"
     )
-    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
+    uncond_embeddings = text_encoder.value(uncond_input.input_ids.to(device.value))[0]
     text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
-    noise_scheduler.set_timesteps(ddim_steps)
-    latents = latents * noise_scheduler.init_noise_sigma
+    noise_scheduler.value.set_timesteps(ddim_steps)
+    latents = latents * noise_scheduler.value.init_noise_sigma
 
 
 
-    for i,t in enumerate(tqdm.tqdm(noise_scheduler.timesteps)):
+    for i,t in enumerate(tqdm.tqdm(noise_scheduler.value.timesteps)):
         latent_model_input = torch.cat([latents] * 2)
-        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
+        latent_model_input = noise_scheduler.value.scale_model_input(latent_model_input, timestep=t)
 
         if t>start_noise:
             pass
         elif t<=start_noise:
-            network.proj = torch.nn.Parameter(edited_weights)
-            network.reset()
+            network.value.proj = torch.nn.Parameter(edited_weights)
+            network.value.reset()
 
 
         with network:
-            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond= None).sample
+            noise_pred = unet.value(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond= None).sample
 
 
         #guidance
         noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
         noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
+        latents = noise_scheduler.value.step(noise_pred, t, latents).prev_sample
 
     latents = 1 / 0.18215 * latents
-    image = vae.decode(latents).sample
+    image = vae.value.decode(latents).sample
     image = (image / 2 + 0.5).clamp(0, 1)
 
     image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
@@ -179,8 +199,8 @@ def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, st
     image = Image.fromarray((image * 255).round().astype("uint8"))
 
     #reset weights back to original
-    network.proj = torch.nn.Parameter(original_weights)
-    network.reset()
+    network.value.proj = torch.nn.Parameter(original_weights)
+    network.value.reset()
 
     return image
 
@@ -193,52 +213,9 @@ def sample_then_run():
     cfg = 3.0
     steps = 25
     image = inference( prompt, negative_prompt, cfg, steps, seed)
-    torch.save(network.proj, "model.pt" )
+    torch.save(network.value.proj, "model.pt" )
     return image, "model.pt"
 
-#@spaces.GPU
-def start_items():
-    print("Starting items")
-    global young
-    global pointy
-    global wavy
-    global thick
-    young = get_direction(df, "Young", pinverse, 1000, device)
-    young = debias(young, "Male", df, pinverse, device)
-    young = debias(young, "Pointy_Nose", df, pinverse, device)
-    young = debias(young, "Wavy_Hair", df, pinverse, device)
-    young = debias(young, "Chubby", df, pinverse, device)
-    young = debias(young, "No_Beard", df, pinverse, device)
-    young = debias(young, "Mustache", df, pinverse, device)
-
-    pointy = get_direction(df, "Pointy_Nose", pinverse, 1000, device)
-    pointy = debias(pointy, "Young", df, pinverse, device)
-    pointy = debias(pointy, "Male", df, pinverse, device)
-    pointy = debias(pointy, "Wavy_Hair", df, pinverse, device)
-    pointy = debias(pointy, "Chubby", df, pinverse, device)
-    pointy = debias(pointy, "Heavy_Makeup", df, pinverse, device)
-
-    wavy = get_direction(df, "Wavy_Hair", pinverse, 1000, device)
-    wavy = debias(wavy, "Young", df, pinverse, device)
-    wavy = debias(wavy, "Male", df, pinverse, device)
-    wavy = debias(wavy, "Pointy_Nose", df, pinverse, device)
-    wavy = debias(wavy, "Chubby", df, pinverse, device)
-    wavy = debias(wavy, "Heavy_Makeup", df, pinverse, device)
-
-    thick = get_direction(df, "Bushy_Eyebrows", pinverse, 1000, device)
-    thick = debias(thick, "Male", df, pinverse, device)
-    thick = debias(thick, "Young", df, pinverse, device)
-    thick = debias(thick, "Pointy_Nose", df, pinverse, device)
-    thick = debias(thick, "Wavy_Hair", df, pinverse, device)
-    thick = debias(thick, "Mustache", df, pinverse, device)
-    thick = debias(thick, "No_Beard", df, pinverse, device)
-    thick = debias(thick, "Sideburns", df, pinverse, device)
-    thick = debias(thick, "Big_Nose", df, pinverse, device)
-    thick = debias(thick, "Big_Lips", df, pinverse, device)
-    thick = debias(thick, "Black_Hair", df, pinverse, device)
-    thick = debias(thick, "Brown_Hair", df, pinverse, device)
-    thick = debias(thick, "Pale_Skin", df, pinverse, device)
-    thick = debias(thick, "Heavy_Makeup", df, pinverse, device)
 
 class CustomImageDataset(Dataset):
     def __init__(self, images, transform=None):
 