xilluill commited on
Commit
c27f75b
·
1 Parent(s): 4291344

update load text encoder

Browse files
Files changed (2) hide show
  1. flux/modules/conditioner.py +13 -5
  2. flux/util.py +4 -2
flux/modules/conditioner.py CHANGED
@@ -10,12 +10,20 @@ class HFEmbedder(nn.Module):
10
  self.max_length = max_length
11
  self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
12
 
13
- if self.is_clip:
14
- self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
15
- self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
 
 
 
 
16
  else:
17
- self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
18
- self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)
 
 
 
 
19
 
20
  self.hf_module = self.hf_module.eval().requires_grad_(False)
21
 
 
10
  self.max_length = max_length
11
  self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
12
 
13
+ if version == 'black-forest-labs/FLUX.1-dev':
14
+ if self.is_clip:
15
+ self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length, subfolder="tokenizer")
16
+ self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version,subfolder='text_encoder' , **hf_kwargs)
17
+ else:
18
+ self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length, subfolder="tokenizer_2")
19
+ self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version,subfolder='text_encoder_2' , **hf_kwargs)
20
  else:
21
+ if self.is_clip:
22
+ self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
23
+ self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
24
+ else:
25
+ self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
26
+ self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)
27
 
28
  self.hf_module = self.hf_module.eval().requires_grad_(False)
29
 
flux/util.py CHANGED
@@ -128,11 +128,13 @@ def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download:
128
 
129
  def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
130
  # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
131
- return HFEmbedder("google/t5-v1_1-xxl", max_length=max_length, is_clip=False, torch_dtype=torch.bfloat16).to(device)
 
132
 
133
 
134
  def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
135
- return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, is_clip=True, torch_dtype=torch.bfloat16).to(device)
 
136
 
137
 
138
  def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
 
128
 
129
  def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
130
  # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
131
+ return HFEmbedder("black-forest-labs/FLUX.1-dev", max_length=max_length, is_clip=False, torch_dtype=torch.bfloat16).to(device)
132
+ # return HFEmbedder("google/t5-v1_1-xxl", max_length=max_length, is_clip=False, torch_dtype=torch.bfloat16).to(device)
133
 
134
 
135
  def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
136
+ return HFEmbedder("black-forest-labs/FLUX.1-dev", max_length=77, is_clip=True, torch_dtype=torch.bfloat16).to(device)
137
+ # return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, is_clip=True, torch_dtype=torch.bfloat16).to(device)
138
 
139
 
140
  def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder: