lnyan commited on
Commit
5d69e29
·
1 Parent(s): 5cea7d2
Files changed (2) hide show
  1. app.py +1 -1
  2. flux/modules/conditioner.py +6 -4
app.py CHANGED
@@ -266,5 +266,5 @@ def create_demo(model_name: str, device: str = "cuda", offload: bool = False):
266
  # parser.add_argument("--share", action="store_true", help="Create a public link to your demo")
267
  # args = parser.parse_args()
268
 
269
- demo = create_demo("flux-dev", None, False)
270
  demo.launch()
 
266
  # parser.add_argument("--share", action="store_true", help="Create a public link to your demo")
267
  # args = parser.parse_args()
268
 
269
+ demo = create_demo("flux-schnell", None, False)
270
  demo.launch()
flux/modules/conditioner.py CHANGED
@@ -15,13 +15,15 @@ class HFEmbedder(nnx.Module):
15
  if self.is_clip:
16
  self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
17
  # self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
18
- self.hf_module: FlaxCLIPTextModel = FlaxCLIPTextModel.from_pretrained(version, **hf_kwargs)
19
  else:
20
  self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
21
  # self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)
22
- self.hf_module: FlaxT5EncoderModel = FlaxT5EncoderModel.from_pretrained(version, **hf_kwargs)
23
- if dtype==jnp.bfloat16:
24
- self.hf_module.params = self.hf_module.to_bf16(self.hf_module.params)
 
 
25
 
26
  def tokenize(self, text: list[str]) -> Tensor:
27
  batch_encoding = self.tokenizer(
 
15
  if self.is_clip:
16
  self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
17
  # self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
18
+ self.hf_module, params = FlaxCLIPTextModel.from_pretrained(version, _do_init=False, **hf_kwargs)
19
  else:
20
  self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
21
  # self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)
22
+ self.hf_module, params = FlaxT5EncoderModel.from_pretrained(version, _do_init=False,**hf_kwargs)
23
+ self.hf_module._is_initialized = True
24
+ import jax
25
+ self.hf_module.params = jax.tree_map(lambda x: jax.device_put(x, jax.devices("cuda")[0]), params)
26
+ # if dtype==jnp.bfloat16:
27
 
28
  def tokenize(self, text: list[str]) -> Tensor:
29
  batch_encoding = self.tokenizer(