
vilarin committed 3c4c329 (verified) · parent: 17931cc

Update app.py

Files changed (1)
  1. app.py +7 -8
app.py CHANGED
@@ -23,12 +23,12 @@ class ModelWrapper:
         self.DTYPE = torch.float16
         self.device = 0
 
-        self.tokenizer_one = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, use_fast=False)
-        self.tokenizer_two = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, use_fast=False)
+        self.tokenizer_one = AutoTokenizer.from_pretrained(model_id, revision=revision, use_fast=False)
+        self.tokenizer_two = AutoTokenizer.from_pretrained(model_id, revision=revision, use_fast=False)
 
         self.text_encoder = SDXLTextEncoder(model_id, revision, accelerator, dtype=self.DTYPE)
 
-        self.vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").float().to(self.device)
+        self.vae = AutoencoderKL.from_pretrained(model_id).float().to(self.device)
         self.vae_dtype = torch.float32
 
         self.tiny_vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=self.DTYPE).to(self.device)
@@ -43,12 +43,12 @@ class ModelWrapper:
         self.vae_downsample_ratio = image_resolution // latent_resolution
         self.conditioning_timestep = conditioning_timestep
 
-        self.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
+        self.scheduler = DDIMScheduler.from_pretrained(model_id)
         self.alphas_cumprod = self.scheduler.alphas_cumprod.to(self.device)
         self.num_step = num_step
 
     def create_generator(self, model_id, checkpoint_path):
-        generator = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to(self.DTYPE)
+        generator = UNet2DConditionModel.from_pretrained(model_id).to(self.DTYPE)
         state_dict = torch.load(checkpoint_path)
         generator.load_state_dict(state_dict, strict=True)
         generator.requires_grad_(False)
@@ -80,7 +80,6 @@ class ModelWrapper:
     @spaces.GPU()
     def sample(self, noise, unet_added_conditions, prompt_embed, fast_vae_decode):
         alphas_cumprod = self.scheduler.alphas_cumprod.to(self.device)
-        print(alphas_cumprod)
         if self.num_step == 1:
             all_timesteps = [self.conditioning_timestep]
             step_interval = 0
@@ -173,8 +172,8 @@ class SDXLTextEncoder(torch.nn.Module):
     def __init__(self, model_id, revision, accelerator, dtype=torch.float16):
         super().__init__()
 
-        self.text_encoder_one = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", revision=revision).to(0).to(dtype=dtype)
-        self.text_encoder_two = CLIPTextModelWithProjection.from_pretrained(model_id, subfolder="text_encoder_2", revision=revision).to(0).to(dtype=dtype)
+        self.text_encoder_one = CLIPTextModel.from_pretrained(model_id, revision=revision).to(0).to(dtype=dtype)
+        self.text_encoder_two = CLIPTextModelWithProjection.from_pretrained(model_id, revision=revision).to(0).to(dtype=dtype)
 
         self.accelerator = accelerator
 
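For context on the `subfolder=` arguments this commit removes: in diffusers and transformers, `subfolder` tells `from_pretrained` to load a component from a directory inside a larger pipeline repository, while omitting it means the given `model_id` must itself point at a repo (or local path) whose root contains that single component's config and weights. Below is a minimal sketch of the two loading styles; the repo id and local path are illustrative assumptions, not values taken from this commit.

from diffusers import AutoencoderKL, DDIMScheduler, UNet2DConditionModel
from transformers import AutoTokenizer

# Style before this commit: one SDXL pipeline repo, with each component
# addressed via its subfolder (repo id assumed here for illustration).
pipeline_repo = "stabilityai/stable-diffusion-xl-base-1.0"
vae = AutoencoderKL.from_pretrained(pipeline_repo, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(pipeline_repo, subfolder="unet")
scheduler = DDIMScheduler.from_pretrained(pipeline_repo, subfolder="scheduler")
tokenizer = AutoTokenizer.from_pretrained(pipeline_repo, subfolder="tokenizer", use_fast=False)

# Style after this commit: no subfolder, so each call expects the repo root
# (or local directory) to hold that component's files directly
# (hypothetical path shown here).
component_repo = "./extracted_vae"
vae = AutoencoderKL.from_pretrained(component_repo)

Which style is correct depends entirely on how the repository behind `model_id` is laid out, which is not visible in this diff.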