Spaces:
vilarin
/
Running on Zero

vilarin commited on
Commit
3968f74
·
verified ·
1 Parent(s): 3c4c329

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -7
app.py CHANGED
@@ -23,12 +23,12 @@ class ModelWrapper:
23
  self.DTYPE = torch.float16
24
  self.device = 0
25
 
26
- self.tokenizer_one = AutoTokenizer.from_pretrained(model_id, revision=revision, use_fast=False)
27
- self.tokenizer_two = AutoTokenizer.from_pretrained(model_id, revision=revision, use_fast=False)
28
 
29
  self.text_encoder = SDXLTextEncoder(model_id, revision, accelerator, dtype=self.DTYPE)
30
 
31
- self.vae = AutoencoderKL.from_pretrained(model_id).float().to(self.device)
32
  self.vae_dtype = torch.float32
33
 
34
  self.tiny_vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=self.DTYPE).to(self.device)
@@ -43,12 +43,12 @@ class ModelWrapper:
43
  self.vae_downsample_ratio = image_resolution // latent_resolution
44
  self.conditioning_timestep = conditioning_timestep
45
 
46
- self.scheduler = DDIMScheduler.from_pretrained(model_id)
47
  self.alphas_cumprod = self.scheduler.alphas_cumprod.to(self.device)
48
  self.num_step = num_step
49
 
50
  def create_generator(self, model_id, checkpoint_path):
51
- generator = UNet2DConditionModel.from_pretrained(model_id).to(self.DTYPE)
52
  state_dict = torch.load(checkpoint_path)
53
  generator.load_state_dict(state_dict, strict=True)
54
  generator.requires_grad_(False)
@@ -172,8 +172,8 @@ class SDXLTextEncoder(torch.nn.Module):
172
  def __init__(self, model_id, revision, accelerator, dtype=torch.float16):
173
  super().__init__()
174
 
175
- self.text_encoder_one = CLIPTextModel.from_pretrained(model_id, revision=revision).to(0).to(dtype=dtype)
176
- self.text_encoder_two = CLIPTextModelWithProjection.from_pretrained(model_id, revision=revision).to(0).to(dtype=dtype)
177
 
178
  self.accelerator = accelerator
179
 
 
23
  self.DTYPE = torch.float16
24
  self.device = 0
25
 
26
+ self.tokenizer_one = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, use_fast=False)
27
+ self.tokenizer_two = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, use_fast=False)
28
 
29
  self.text_encoder = SDXLTextEncoder(model_id, revision, accelerator, dtype=self.DTYPE)
30
 
31
+ self.vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").float().to(self.device)
32
  self.vae_dtype = torch.float32
33
 
34
  self.tiny_vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=self.DTYPE).to(self.device)
 
43
  self.vae_downsample_ratio = image_resolution // latent_resolution
44
  self.conditioning_timestep = conditioning_timestep
45
 
46
+ self.scheduler = DDIMScheduler.from_pretrained(model_id,subfolder="scheduler")
47
  self.alphas_cumprod = self.scheduler.alphas_cumprod.to(self.device)
48
  self.num_step = num_step
49
 
50
  def create_generator(self, model_id, checkpoint_path):
51
+ generator = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to(self.DTYPE)
52
  state_dict = torch.load(checkpoint_path)
53
  generator.load_state_dict(state_dict, strict=True)
54
  generator.requires_grad_(False)
 
172
  def __init__(self, model_id, revision, accelerator, dtype=torch.float16):
173
  super().__init__()
174
 
175
+ self.text_encoder_one = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", revision=revision).to(0).to(dtype=dtype)
176
+ self.text_encoder_two = CLIPTextModelWithProjection.from_pretrained(model_id, subfolder="text_encoder_2", revision=revision).to(0).to(dtype=dtype)
177
 
178
  self.accelerator = accelerator
179