amildravid4292 commited on
Commit
0bc198e
·
verified ·
1 Parent(s): 7839c32

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -1
app.py CHANGED
@@ -165,7 +165,11 @@ def sample_weights(unet, proj, mean, std, v, device, factor = 1.0):
165
  @torch.no_grad()
166
  @spaces.GPU
167
  def sample_model():
168
- unet.value, _, _, _, _ = load_models()
 
 
 
 
169
  network.value = sample_weights(unet.value, proj, mean, std, v[:, :1000], device.value, factor = 1.00)
170
 
171
  @torch.no_grad()
 
165
  @torch.no_grad()
166
  @spaces.GPU
167
  def sample_model():
168
+ unet.value = UNet2DConditionModel.from_pretrained(
169
+ pretrained_model_name_or_path, subfolder="unet", revision=revision
170
+ )
171
+ unet.value.requires_grad_(False)
172
+ unet.value.to(device.value, dtype=weight_dtype)
173
  network.value = sample_weights(unet.value, proj, mean, std, v[:, :1000], device.value, factor = 1.00)
174
 
175
  @torch.no_grad()