Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -23,12 +23,12 @@ class ModelWrapper:
|
|
23 |
self.DTYPE = torch.float16
|
24 |
self.device = 0
|
25 |
|
26 |
-
self.tokenizer_one = AutoTokenizer.from_pretrained(model_id, revision=revision, use_fast=False)
|
27 |
-
self.tokenizer_two = AutoTokenizer.from_pretrained(model_id, revision=revision, use_fast=False)
|
28 |
|
29 |
self.text_encoder = SDXLTextEncoder(model_id, revision, accelerator, dtype=self.DTYPE)
|
30 |
|
31 |
-
self.vae = AutoencoderKL.from_pretrained(model_id).float().to(self.device)
|
32 |
self.vae_dtype = torch.float32
|
33 |
|
34 |
self.tiny_vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=self.DTYPE).to(self.device)
|
@@ -43,12 +43,12 @@ class ModelWrapper:
|
|
43 |
self.vae_downsample_ratio = image_resolution // latent_resolution
|
44 |
self.conditioning_timestep = conditioning_timestep
|
45 |
|
46 |
-
self.scheduler = DDIMScheduler.from_pretrained(model_id)
|
47 |
self.alphas_cumprod = self.scheduler.alphas_cumprod.to(self.device)
|
48 |
self.num_step = num_step
|
49 |
|
50 |
def create_generator(self, model_id, checkpoint_path):
|
51 |
-
generator = UNet2DConditionModel.from_pretrained(model_id).to(self.DTYPE)
|
52 |
state_dict = torch.load(checkpoint_path)
|
53 |
generator.load_state_dict(state_dict, strict=True)
|
54 |
generator.requires_grad_(False)
|
@@ -172,8 +172,8 @@ class SDXLTextEncoder(torch.nn.Module):
|
|
172 |
def __init__(self, model_id, revision, accelerator, dtype=torch.float16):
|
173 |
super().__init__()
|
174 |
|
175 |
-
self.text_encoder_one = CLIPTextModel.from_pretrained(model_id, revision=revision).to(0).to(dtype=dtype)
|
176 |
-
self.text_encoder_two = CLIPTextModelWithProjection.from_pretrained(model_id, revision=revision).to(0).to(dtype=dtype)
|
177 |
|
178 |
self.accelerator = accelerator
|
179 |
|
|
|
23 |
self.DTYPE = torch.float16
|
24 |
self.device = 0
|
25 |
|
26 |
+
self.tokenizer_one = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, use_fast=False)
|
27 |
+
self.tokenizer_two = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, use_fast=False)
|
28 |
|
29 |
self.text_encoder = SDXLTextEncoder(model_id, revision, accelerator, dtype=self.DTYPE)
|
30 |
|
31 |
+
self.vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").float().to(self.device)
|
32 |
self.vae_dtype = torch.float32
|
33 |
|
34 |
self.tiny_vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=self.DTYPE).to(self.device)
|
|
|
43 |
self.vae_downsample_ratio = image_resolution // latent_resolution
|
44 |
self.conditioning_timestep = conditioning_timestep
|
45 |
|
46 |
+
self.scheduler = DDIMScheduler.from_pretrained(model_id,subfolder="scheduler")
|
47 |
self.alphas_cumprod = self.scheduler.alphas_cumprod.to(self.device)
|
48 |
self.num_step = num_step
|
49 |
|
50 |
def create_generator(self, model_id, checkpoint_path):
|
51 |
+
generator = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to(self.DTYPE)
|
52 |
state_dict = torch.load(checkpoint_path)
|
53 |
generator.load_state_dict(state_dict, strict=True)
|
54 |
generator.requires_grad_(False)
|
|
|
172 |
def __init__(self, model_id, revision, accelerator, dtype=torch.float16):
|
173 |
super().__init__()
|
174 |
|
175 |
+
self.text_encoder_one = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", revision=revision).to(0).to(dtype=dtype)
|
176 |
+
self.text_encoder_two = CLIPTextModelWithProjection.from_pretrained(model_id, subfolder="text_encoder_2", revision=revision).to(0).to(dtype=dtype)
|
177 |
|
178 |
self.accelerator = accelerator
|
179 |
|