manbeast3b committed
Commit 03287aa · verified · 1 Parent(s): b4ec046

Update src/pipeline.py

Files changed (1):
  src/pipeline.py (+15 -12)
src/pipeline.py CHANGED
@@ -577,10 +577,8 @@ torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.enabled = True
 torch.backends.cudnn.benchmark = True
 
-# ckpt_id = "black-forest-labs/FLUX.1-schnell"
-# ckpt_revision = "741f7c3ce8b383c54771c7003378a50191e9efe9"
-ckpt_id = "silentdriver/4b68f38c0b"
-ckpt_revision = "36a3cf4a9f733fc5f31257099b56b304fb2eceab"
+ckpt_id = "black-forest-labs/FLUX.1-schnell"
+ckpt_revision = "741f7c3ce8b383c54771c7003378a50191e9efe9"
 def empty_cache():
     gc.collect()
     torch.cuda.empty_cache()
@@ -591,36 +589,41 @@ def load_pipeline() -> Pipeline:
     empty_cache()
 
     dtype, device = torch.bfloat16, "cuda"
+
+    text_encoder = CLIPTextModel.from_pretrained(
+        ckpt_id, subfolder="text_encoder", torch_dtype=torch.bfloat16
+    )
 
     text_encoder_2 = T5EncoderModel.from_pretrained(
         "city96/t5-v1_1-xxl-encoder-bf16", revision = "1b9c856aadb864af93c1dcdc226c2774fa67bc86", torch_dtype=torch.bfloat16
     ).to(memory_format=torch.channels_last)
-
-
+
+    text_encoder = CLIPTextModel.from_pretrained(
+        os.path.join(HF_HUB_CACHE, "models--manbeast3b--FLUX.1-schnell-te1/snapshots/05ac3e466d6b42b7794859560d875b25f6df5daf"), subfolder="text_encoder", torch_dtype=torch.bfloat16
+    ).to(memory_format=torch.channels_last)
+
     vae = AutoencoderTiny.from_pretrained("RobertML/FLUX.1-schnell-vae_e3m2", revision="da0d2cd7815792fb40d084dbd8ed32b63f153d8d", torch_dtype=dtype)
 
-    path = os.path.join(HF_HUB_CACHE, "models--RobertML--FLUX.1-schnell-int8wo/snapshots/307e0777d92df966a3c0f99f31a6ee8957a9857a")
+    path = os.path.join(HF_HUB_CACHE, "models--manbeast3b--FLUX.1-schnell-transformer-f8/snapshots/2ac0d29a2f3a00175fd638e82e8acaa4ddcbfd09")
     generator = torch.Generator(device=device)
     model = FluxTransformer2DModel.from_pretrained(path, torch_dtype=dtype, use_safetensors=False, generator= generator).to(memory_format=torch.channels_last)
     torch.backends.cudnn.benchmark = True
     torch.backends.cudnn.deterministic = False
-    # model = torch.compile(model, mode="max-autotune-no-cudagraphs")
-    # model = torch.compile(model,backend="aot_eager")
-    vae = torch.compile(vae)
+
     pipeline = DiffusionPipeline.from_pretrained(
         ckpt_id,
+        text_encoder=text_encoder,
         vae=vae,
         revision=ckpt_revision,
         transformer=model,
         text_encoder_2=text_encoder_2,
         torch_dtype=dtype,
     ).to(device)
+    pipeline.vae = torch.compile(pipeline.vae)
    pipeline.vae.requires_grad_(False)
    pipeline.transformer.requires_grad_(False)
    pipeline.text_encoder_2.requires_grad_(False)
    pipeline.text_encoder.requires_grad_(False)
-
-    # pipeline.enable_sequential_cpu_offload(exclude=["transformer"])
 
     for _ in range(3):
         pipeline(prompt="blah blah waah waah oneshot oneshot gang gang", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256)
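
For readability, here is a minimal sketch of the effective load_pipeline() after this commit, consolidated from the new side of the diff. The import block is an assumption (the diff does not show the module header), the trailing return is implied by the "-> Pipeline" annotation but not shown in the hunk, and the generator kwarg the diff passes to FluxTransformer2DModel.from_pretrained is omitted here since it is not a loading parameter. Note that the diff assigns text_encoder twice, so the first hub load is overwritten before use; the sketch keeps only the surviving assignment.

import gc
import os

import torch
from diffusers import AutoencoderTiny, DiffusionPipeline, FluxTransformer2DModel
from huggingface_hub.constants import HF_HUB_CACHE
from transformers import CLIPTextModel, T5EncoderModel

ckpt_id = "black-forest-labs/FLUX.1-schnell"
ckpt_revision = "741f7c3ce8b383c54771c7003378a50191e9efe9"

def empty_cache():
    gc.collect()
    torch.cuda.empty_cache()

def load_pipeline():
    empty_cache()
    dtype, device = torch.bfloat16, "cuda"

    # Only this CLIP load survives; the earlier from_pretrained(ckpt_id, ...)
    # in the diff is overwritten before it is ever used.
    text_encoder = CLIPTextModel.from_pretrained(
        os.path.join(HF_HUB_CACHE, "models--manbeast3b--FLUX.1-schnell-te1/snapshots/05ac3e466d6b42b7794859560d875b25f6df5daf"),
        subfolder="text_encoder", torch_dtype=dtype,
    ).to(memory_format=torch.channels_last)

    text_encoder_2 = T5EncoderModel.from_pretrained(
        "city96/t5-v1_1-xxl-encoder-bf16",
        revision="1b9c856aadb864af93c1dcdc226c2774fa67bc86", torch_dtype=dtype,
    ).to(memory_format=torch.channels_last)

    vae = AutoencoderTiny.from_pretrained(
        "RobertML/FLUX.1-schnell-vae_e3m2",
        revision="da0d2cd7815792fb40d084dbd8ed32b63f153d8d", torch_dtype=dtype,
    )

    transformer = FluxTransformer2DModel.from_pretrained(
        os.path.join(HF_HUB_CACHE, "models--manbeast3b--FLUX.1-schnell-transformer-f8/snapshots/2ac0d29a2f3a00175fd638e82e8acaa4ddcbfd09"),
        torch_dtype=dtype, use_safetensors=False,
    ).to(memory_format=torch.channels_last)

    pipeline = DiffusionPipeline.from_pretrained(
        ckpt_id,
        revision=ckpt_revision,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        vae=vae,
        transformer=transformer,
        torch_dtype=dtype,
    ).to(device)

    pipeline.vae = torch.compile(pipeline.vae)  # compile only the tiny VAE
    for module in (pipeline.vae, pipeline.transformer,
                   pipeline.text_encoder, pipeline.text_encoder_2):
        module.requires_grad_(False)  # inference only

    # Warm-up passes trigger cuDNN autotune and torch.compile tracing up front.
    for _ in range(3):
        pipeline(prompt="blah blah waah waah oneshot oneshot gang gang",
                 width=1024, height=1024, guidance_scale=0.0,
                 num_inference_steps=4, max_sequence_length=256)
    return pipeline

The warm-up call signature (guidance_scale=0.0, num_inference_steps=4) matches the distilled FLUX.1-schnell inference configuration, so a caller can reuse it verbatim, e.g. image = load_pipeline()(prompt="...", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256).images[0].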