quickjkee committed (verified)
Commit 3a99770 · Parent: f3f05b4

Update pipeline.py

Files changed (1):
  1. pipeline.py +55 -0
pipeline.py CHANGED
@@ -16,6 +16,18 @@
 import torch
 from typing import Any, Callable, Dict, List, Union, Optional
 from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+ from diffusers.models.autoencoders import AutoencoderKL
+ from diffusers.models.transformers import SD3Transformer2DModel
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+ from transformers import (
+     CLIPTextModelWithProjection,
+     CLIPTokenizer,
+     SiglipImageProcessor,
+     SiglipVisionModel,
+     T5EncoderModel,
+     T5TokenizerFast,
+ )
+
 from diffusers.utils import (
     USE_PEFT_BACKEND,
     is_torch_xla_available,
@@ -38,6 +50,49 @@ else:

 class SwDPipeline(DiffusionPipeline):

+     def __init__(
+         self,
+         transformer: SD3Transformer2DModel,
+         scheduler: FlowMatchEulerDiscreteScheduler,
+         vae: AutoencoderKL,
+         text_encoder: CLIPTextModelWithProjection,
+         tokenizer: CLIPTokenizer,
+         text_encoder_2: CLIPTextModelWithProjection,
+         tokenizer_2: CLIPTokenizer,
+         text_encoder_3: T5EncoderModel,
+         tokenizer_3: T5TokenizerFast,
+         image_encoder: SiglipVisionModel = None,
+         feature_extractor: SiglipImageProcessor = None,
+     ):
+         super().__init__()
+
+         self.register_modules(
+             vae=vae,
+             text_encoder=text_encoder,
+             text_encoder_2=text_encoder_2,
+             text_encoder_3=text_encoder_3,
+             tokenizer=tokenizer,
+             tokenizer_2=tokenizer_2,
+             tokenizer_3=tokenizer_3,
+             transformer=transformer,
+             scheduler=scheduler,
+             image_encoder=image_encoder,
+             feature_extractor=feature_extractor,
+         )
+         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+         self.tokenizer_max_length = (
+             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+         )
+         self.default_sample_size = (
+             self.transformer.config.sample_size
+             if hasattr(self, "transformer") and self.transformer is not None
+             else 128
+         )
+         self.patch_size = (
+             self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2
+         )
+
     @torch.no_grad()
     def __call__(
         self,
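
With this commit, SwDPipeline's constructor mirrors the standard Stable Diffusion 3 component layout (MMDiT transformer, flow-match scheduler, VAE, two CLIP text encoders, a T5 encoder, and an optional SigLIP image encoder), so diffusers can assemble it from an SD3-style checkpoint through its custom-pipeline mechanism. A minimal loading sketch, assuming a hypothetical checkpoint id and file location (REPO_ID, the custom_pipeline path, and the prompt are placeholders, not part of this commit):

import torch
from diffusers import DiffusionPipeline

# from_pretrained resolves each checkpoint subfolder (vae, text_encoder, ...)
# to the matching __init__ argument registered via register_modules above.
pipe = DiffusionPipeline.from_pretrained(
    "REPO_ID",                      # placeholder: an SD3-family checkpoint
    custom_pipeline="path/to/dir",  # placeholder: directory containing this pipeline.py
    torch_dtype=torch.bfloat16,
)
pipe = pipe.to("cuda")

# Assuming __call__ keeps the usual SD3 output contract:
image = pipe(prompt="a photo of a corgi").images[0]

The fallback values in __init__ are the usual SD3 defaults: the SD3 VAE config lists four block_out_channels entries, so vae_scale_factor = 2 ** (4 - 1) = 8, and a transformer sample_size of 128 at that scale factor corresponds to 1024x1024 pixel outputs (with patch_size 2 applied at the latent level).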