update load text encoder
- flux/modules/conditioner.py +13 -5
- flux/util.py +4 -2
flux/modules/conditioner.py
CHANGED
@@ -10,12 +10,20 @@ class HFEmbedder(nn.Module):
         self.max_length = max_length
         self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
 
-        if self.is_clip:
-            self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
-            self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
+        if version == "black-forest-labs/FLUX.1-dev":
+            if self.is_clip:
+                self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length, subfolder="tokenizer")
+                self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, subfolder="text_encoder", **hf_kwargs)
+            else:
+                self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length, subfolder="tokenizer_2")
+                self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, subfolder="text_encoder_2", **hf_kwargs)
         else:
-            self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
-            self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)
+            if self.is_clip:
+                self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
+                self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
+            else:
+                self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
+                self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)
 
         self.hf_module = self.hf_module.eval().requires_grad_(False)
 
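Note that in the FLUX.1-dev repo the CLIP weights live under tokenizer/text_encoder and the T5 weights under tokenizer_2/text_encoder_2, so the tokenizer and model classes must match the subfolder being loaded. As a sanity check, here is a minimal standalone sketch of what the new branch resolves to (assuming access to the gated black-forest-labs/FLUX.1-dev checkpoint; the calls are standard transformers API):

# Standalone sketch of the subfolder loads performed by the new branch.
# Assumes the gated black-forest-labs/FLUX.1-dev repo is downloadable.
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer

repo = "black-forest-labs/FLUX.1-dev"

# CLIP text encoder: pooled embedding ("pooler_output")
clip_tokenizer = CLIPTokenizer.from_pretrained(repo, subfolder="tokenizer")
clip_encoder = CLIPTextModel.from_pretrained(repo, subfolder="text_encoder", torch_dtype=torch.bfloat16)

# T5 text encoder: per-token embeddings ("last_hidden_state")
t5_tokenizer = T5Tokenizer.from_pretrained(repo, subfolder="tokenizer_2")
t5_encoder = T5EncoderModel.from_pretrained(repo, subfolder="text_encoder_2", torch_dtype=torch.bfloat16)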
flux/util.py
CHANGED
@@ -128,11 +128,13 @@ def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download:
 
 
 def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
     # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
-    return HFEmbedder("google/t5-v1_1-xxl", max_length=max_length, torch_dtype=torch.bfloat16).to(device)
+    return HFEmbedder("black-forest-labs/FLUX.1-dev", max_length=max_length, is_clip=False, torch_dtype=torch.bfloat16).to(device)
+    # return HFEmbedder("google/t5-v1_1-xxl", max_length=max_length, is_clip=False, torch_dtype=torch.bfloat16).to(device)
 
 
 def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
-    return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device)
+    return HFEmbedder("black-forest-labs/FLUX.1-dev", max_length=77, is_clip=True, torch_dtype=torch.bfloat16).to(device)
+    # return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, is_clip=True, torch_dtype=torch.bfloat16).to(device)
 
 
 def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
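With both loaders now pointing at the single FLUX.1-dev repo, the call sites stay unchanged. A brief usage sketch, assuming HFEmbedder keeps the upstream flux forward signature (a list of prompts in, an embedding tensor out) and the upstream encoder sizes (T5-XXL hidden size 4096, CLIP ViT-L pooled size 768):

# Usage sketch; shapes are assumptions based on the upstream flux encoders.
import torch
from flux.util import load_clip, load_t5

t5 = load_t5("cuda", max_length=512)
clip = load_clip("cuda")

prompt = ["a photo of a forest with mist"]
with torch.no_grad():
    txt = t5(prompt)    # last_hidden_state, shape (1, 512, 4096)
    vec = clip(prompt)  # pooler_output, shape (1, 768)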