qyoo committed
Commit 88c3674 · 1 Parent(s): 477bc72

update to bf16
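For context on the switch (not part of the commit itself): bfloat16 keeps float32's 8-bit exponent and only shortens the mantissa, so activations that overflow float16's ±65504 range survive in bf16. A quick check:

```python
import torch

# float16 overflows early; bfloat16 has the same dynamic range as float32.
print(torch.finfo(torch.float16).max)   # 65504.0
print(torch.finfo(torch.bfloat16).max)  # ~3.39e38, same order of magnitude as float32
```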
app.py CHANGED
@@ -61,7 +61,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 adapter_name = "h94/IP-Adapter/models/ip-adapter-plus_sd15.bin"
 pipe = StableDiffusionCustomPipeline.from_pretrained(
 "SG161222/Realistic_Vision_V5.1_noVAE",
- torch_dtype=torch.float16,
+ torch_dtype=torch.bfloat16,
 feature_extractor=None,
 safety_checker=None
 )
@@ -91,7 +91,7 @@ def change_model_fn(model_name: str) -> None:
 pipe = StableDiffusionXLCustomPipeline.from_pretrained(
 name_mapping[model_name],
 # variant="fp16",
- torch_dtype=torch.float16,
+ torch_dtype=torch.bfloat16,
 feature_extractor=None
 )
 pipeline = ConceptrolIPAdapterPlusXL(pipe, "", adapter_name, device, num_tokens=16)
@@ -117,7 +117,7 @@ def change_model_fn(model_name: str) -> None:
 adapter_name = "h94/IP-Adapter/models/ip-adapter-plus_sd15.bin"
 pipe = StableDiffusionCustomPipeline.from_pretrained(
 name_mapping[model_name],
- torch_dtype=torch.float16,
+ torch_dtype=torch.bfloat16,
 feature_extractor=None,
 safety_checker=None
 )
@@ -389,4 +389,4 @@ with gr.Blocks(css="style.css") as demo:
 )
 gr.Markdown(article)

- demo.launch(share=True)
+ demo.launch()
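Hard-coding `torch.bfloat16` assumes the app runs on a GPU with native bf16 support (Ampere or newer). A hypothetical fallback helper, sketched here rather than taken from the repository, would keep the old fp16 path available on older hardware:

```python
import torch

def pick_dtype() -> torch.dtype:
    """Hypothetical helper (not in app.py): choose the best available dtype."""
    if torch.cuda.is_available():
        # bf16 needs Ampere or newer; otherwise fall back to fp16 on GPU.
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    return torch.float32  # full precision is the safe default on CPU

# Usage would mirror the calls above, e.g.:
# pipe = StableDiffusionCustomPipeline.from_pretrained(
#     "SG161222/Realistic_Vision_V5.1_noVAE",
#     torch_dtype=pick_dtype(),
#     feature_extractor=None,
#     safety_checker=None,
# )
```

Dropping `share=True` from `demo.launch()` is unrelated to the dtype change; on a hosted Space the app is already served publicly, so the flag is unnecessary there.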
ip_adapter/custom_pipelines.py CHANGED
@@ -408,7 +408,7 @@ class StableDiffusionXLCustomPipeline(StableDiffusionXLPipeline):
 if not output_type == "latent":
 # make sure the VAE is in float32 mode, as it overflows in float16
 needs_upcasting = (
- self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+ self.vae.dtype == torch.bfloat16 and self.vae.config.force_upcast
 )

 if needs_upcasting:
@@ -423,7 +423,7 @@ class StableDiffusionXLCustomPipeline(StableDiffusionXLPipeline):

 # cast back to fp16 if needed
 if needs_upcasting:
- self.vae.to(dtype=torch.float16)
+ self.vae.to(dtype=torch.bfloat16)
 else:
 image = latents

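Note that after this change the guard only fires when the VAE is in bf16, and the `# cast back to fp16 if needed` comment no longer matches what the code does. A dtype-agnostic variant (a sketch assuming the VAE is a diffusers `AutoencoderKL`; `vae_upcast` is an invented name, not part of this commit) would remember and restore whatever reduced-precision dtype the VAE started in:

```python
import contextlib
import torch

@contextlib.contextmanager
def vae_upcast(vae):
    """Temporarily run the VAE in float32 when it was loaded in fp16 or bf16
    and its config requests upcasting; restore the original dtype on exit."""
    original_dtype = vae.dtype
    needs_upcasting = (
        original_dtype in (torch.float16, torch.bfloat16) and vae.config.force_upcast
    )
    if needs_upcasting:
        vae.to(dtype=torch.float32)
    try:
        yield vae
    finally:
        if needs_upcasting:
            vae.to(dtype=original_dtype)  # cast back to the dtype we started in

# Inside the pipeline's decode step this would read roughly:
# with vae_upcast(self.vae) as vae:
#     image = vae.decode(latents.to(vae.dtype) / vae.config.scaling_factor)[0]
```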
ip_adapter/ip_adapter.py CHANGED
@@ -85,8 +85,8 @@ class IPAdapter:
 self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
 "h94/IP-Adapter",
 subfolder="models/image_encoder",
- torch_dtype=torch.float16,
- ).to(self.device, dtype=torch.float16)
+ torch_dtype=torch.bfloat16,
+ ).to(self.device, dtype=torch.bfloat16)
 self.clip_image_processor = CLIPImageProcessor()
 # image proj model
 self.image_proj_model = self.init_proj()
@@ -98,7 +98,7 @@
 cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
 clip_embeddings_dim=self.image_encoder.config.projection_dim,
 clip_extra_context_tokens=self.num_tokens,
- ).to(self.device, dtype=torch.float16)
+ ).to(self.device, dtype=torch.bfloat16)
 return image_proj_model

 def set_ip_adapter(self):
@@ -126,7 +126,7 @@
 cross_attention_dim=cross_attention_dim,
 scale=1.0,
 num_tokens=self.num_tokens,
- ).to(self.device, dtype=torch.float16)
+ ).to(self.device, dtype=torch.bfloat16)
 unet.set_attn_processor(attn_procs)
 if hasattr(self.pipe, "controlnet"):
 if isinstance(self.pipe.controlnet, MultiControlNetModel):
@@ -167,10 +167,10 @@
 images=pil_image, return_tensors="pt"
 ).pixel_values
 clip_image_embeds = self.image_encoder(
- clip_image.to(self.device, dtype=torch.float16)
+ clip_image.to(self.device, dtype=torch.bfloat16)
 ).image_embeds
 else:
- clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
+ clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.bfloat16)
 image_prompt_embeds = self.image_proj_model(clip_image_embeds)
 uncond_image_prompt_embeds = self.image_proj_model(
 torch.zeros_like(clip_image_embeds)
@@ -282,8 +282,8 @@ class ConceptrolIPAdapter:
 self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
 "h94/IP-Adapter",
 subfolder="models/image_encoder",
- torch_dtype=torch.float16,
- ).to(self.device, dtype=torch.float16)
+ torch_dtype=torch.bfloat16,
+ ).to(self.device, dtype=torch.bfloat16)
 self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
 self.clip_image_processor = CLIPImageProcessor()
 # image proj model
@@ -296,7 +296,7 @@
 cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
 clip_embeddings_dim=self.image_encoder.config.projection_dim,
 clip_extra_context_tokens=self.num_tokens,
- ).to(self.device, dtype=torch.float16)
+ ).to(self.device, dtype=torch.bfloat16)
 return image_proj_model

 def set_ip_adapter(self, global_masking, adaptive_scale_mask):
@@ -328,7 +328,7 @@
 global_masking=global_masking,
 adaptive_scale_mask=adaptive_scale_mask,
 concept_mask_layer=SD_CONCEPT_LAYER,
- ).to(self.device, dtype=torch.float16)
+ ).to(self.device, dtype=torch.bfloat16)
 unet.set_attn_processor(attn_procs)
 for name in unet.attn_processors.keys(): # noqa: SIM118
 cross_attention_dim = (
@@ -395,10 +395,10 @@
 images=pil_image, return_tensors="pt"
 ).pixel_values
 clip_image_embeds = self.image_encoder(
- clip_image.to(self.device, dtype=torch.float16)
+ clip_image.to(self.device, dtype=torch.bfloat16)
 ).image_embeds
 else:
- clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
+ clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.bfloat16)
 image_prompt_embeds = self.image_proj_model(clip_image_embeds)
 uncond_image_prompt_embeds = self.image_proj_model(
 torch.zeros_like(clip_image_embeds)
@@ -624,7 +624,7 @@ class ConceptrolIPAdapterXL(ConceptrolIPAdapter):
 global_masking=global_masking,
 adaptive_scale_mask=adaptive_scale_mask,
 concept_mask_layer=SDXL_CONCEPT_LAYER,
- ).to(self.device, dtype=torch.float16)
+ ).to(self.device, dtype=torch.bfloat16)
 unet.set_attn_processor(attn_procs)
 for name in unet.attn_processors.keys(): # noqa: SIM118
 cross_attention_dim = (
@@ -743,7 +743,7 @@ class IPAdapterPlus(IPAdapter):
 embedding_dim=self.image_encoder.config.hidden_size,
 output_dim=self.pipe.unet.config.cross_attention_dim,
 ff_mult=4,
- ).to(self.device, dtype=torch.float16)
+ ).to(self.device, dtype=torch.bfloat16)
 return image_proj_model

 @torch.inference_mode()
@@ -753,7 +753,7 @@ class IPAdapterPlus(IPAdapter):
 clip_image = self.clip_image_processor(
 images=pil_image, return_tensors="pt"
 ).pixel_values
- clip_image = clip_image.to(self.device, dtype=torch.float16)
+ clip_image = clip_image.to(self.device, dtype=torch.bfloat16)
 clip_image_embeds = self.image_encoder(
 clip_image, output_hidden_states=True
 ).hidden_states[-2]
@@ -778,7 +778,7 @@ class ConceptrolIPAdapterPlus(ConceptrolIPAdapter):
 embedding_dim=self.image_encoder.config.hidden_size,
 output_dim=self.pipe.unet.config.cross_attention_dim,
 ff_mult=4,
- ).to(self.device, dtype=torch.float16)
+ ).to(self.device, dtype=torch.bfloat16)
 return image_proj_model

 @torch.inference_mode()
@@ -788,7 +788,7 @@ class ConceptrolIPAdapterPlus(ConceptrolIPAdapter):
 clip_image = self.clip_image_processor(
 images=pil_image, return_tensors="pt"
 ).pixel_values
- clip_image = clip_image.to(self.device, dtype=torch.float16)
+ clip_image = clip_image.to(self.device, dtype=torch.bfloat16)
 clip_image_embeds = self.image_encoder(
 clip_image, output_hidden_states=True
 ).hidden_states[-2]
@@ -807,7 +807,7 @@ class IPAdapterFull(IPAdapterPlus):
 image_proj_model = MLPProjModel(
 cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
 clip_embeddings_dim=self.image_encoder.config.hidden_size,
- ).to(self.device, dtype=torch.float16)
+ ).to(self.device, dtype=torch.bfloat16)
 return image_proj_model


@@ -824,7 +824,7 @@ class IPAdapterPlusXL(IPAdapter):
 embedding_dim=self.image_encoder.config.hidden_size,
 output_dim=self.pipe.unet.config.cross_attention_dim,
 ff_mult=4,
- ).to(self.device, dtype=torch.float16)
+ ).to(self.device, dtype=torch.bfloat16)
 return image_proj_model

 @torch.inference_mode()
@@ -834,7 +834,7 @@ class IPAdapterPlusXL(IPAdapter):
 clip_image = self.clip_image_processor(
 images=pil_image, return_tensors="pt"
 ).pixel_values
- clip_image = clip_image.to(self.device, dtype=torch.float16)
+ clip_image = clip_image.to(self.device, dtype=torch.bfloat16)
 clip_image_embeds = self.image_encoder(
 clip_image, output_hidden_states=True
 ).hidden_states[-2]
@@ -937,7 +937,7 @@ class ConceptrolIPAdapterPlusXL(ConceptrolIPAdapterXL):
 embedding_dim=self.image_encoder.config.hidden_size,
 output_dim=self.pipe.unet.config.cross_attention_dim,
 ff_mult=4,
- ).to(self.device, dtype=torch.float16)
+ ).to(self.device, dtype=torch.bfloat16)
 return image_proj_model

 @torch.inference_mode()
@@ -947,7 +947,7 @@ class ConceptrolIPAdapterPlusXL(ConceptrolIPAdapterXL):
 clip_image = self.clip_image_processor(
 images=pil_image, return_tensors="pt"
 ).pixel_values
- clip_image = clip_image.to(self.device, dtype=torch.float16)
+ clip_image = clip_image.to(self.device, dtype=torch.bfloat16)
 clip_image_embeds = self.image_encoder(
 clip_image, output_hidden_states=True
 ).hidden_states[-2]
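The same dtype literal is edited in roughly twenty places across ip_adapter.py. A small refactor, sketched here and not part of the commit, would centralize the choice so the next precision switch is a one-line change:

```python
import torch

# Hypothetical module-level constant for ip_adapter.py; every call site would
# reference it instead of repeating the literal.
ADAPTER_DTYPE = torch.bfloat16

# Call sites would then read, e.g.:
#     ).to(self.device, dtype=ADAPTER_DTYPE)
#     clip_image = clip_image.to(self.device, dtype=ADAPTER_DTYPE)
```

Alternatively, passing a `dtype` argument through the adapter constructors and storing it as `self.dtype` would give the same single point of control with per-instance flexibility.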