manbeast3b committed
Commit 5560a1b · verified · Parent(s): ef90742

Update src/pipeline.py

Files changed (1):
  src/pipeline.py  +44 -0
src/pipeline.py CHANGED
@@ -572,6 +572,42 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
 
         return Transformer2DModelOutput(sample=output)
 
+def load_single_file_checkpoint(
+    pretrained_model_link_or_path,
+    force_download=False,
+    proxies=None,
+    token=None,
+    cache_dir=None,
+    local_files_only=None,
+    revision=None,
+):
+    import pdb; pdb.set_trace()
+    if os.path.isfile(pretrained_model_link_or_path):
+        pretrained_model_link_or_path = pretrained_model_link_or_path
+
+    else:
+        repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path)
+        pretrained_model_link_or_path = _get_model_file(
+            repo_id,
+            weights_name=weights_name,
+            force_download=force_download,
+            cache_dir=cache_dir,
+            proxies=proxies,
+            local_files_only=local_files_only,
+            token=token,
+            revision=revision,
+        )
+    import pdb; pdb.set_trace()
+
+    checkpoint = load_state_dict(pretrained_model_link_or_path)
+
+    # some checkpoints contain the model state dict under a "state_dict" key
+    while "state_dict" in checkpoint:
+        checkpoint = checkpoint["state_dict"]
+
+    return checkpoint
+
+
 Pipeline = None
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.enabled = True
@@ -594,7 +630,15 @@ def load_pipeline() -> Pipeline:
     from diffusers.loaders.single_file_utils import create_diffusers_t5_model_from_checkpoint
     from diffusers.loaders.single_file_model import FromOriginalModelMixin
 
+
     dtype, device = torch.bfloat16, "cuda"
+
+    import pdb; pdb.set_trace()
+
+    t5_path = os.path.join(HF_HUB_CACHE, "models--manbeast3b--t5-v1_1-xxl-encoder-q8/snapshots/59c6c9cb99dcea42067f32caac3ea0836ef4c548/t5-v1_1-xxl-encoder-Q8_0.gguf")
+    # config_path = os.path.join(HF_HUB_CACHE, "models--black-forest--labs/FLUX.1-schnell/snapshots/741f7c3ce8b383c54771c7003378a50191e9efe9/text_encoder_2/config.json")
+    config_path = os.path.join(HF_HUB_CACHE, "models--black-forest-labs--FLUX.1-schnell/snapshots/741f7c3ce8b383c54771c7003378a50191e9efe9/")
+    ckpt_t5 = load_single_file_checkpoint(t5_path,local_files_only=True)
 
     text_encoder_2 = T5EncoderModel.from_pretrained(
         "silentdriver/aadb864af9", revision = "060dabc7fa271c26dfa3fd43c16e7c5bf3ac7892", torch_dtype=torch.bfloat16