Text-to-Image
Diffusers
Safetensors
Commit f66e674 (verified) by tolgacangoz · 1 Parent(s): 481bca0

Upload anytext.py

Files changed (1): anytext.py (+229, -15)
anytext.py CHANGED
@@ -35,7 +35,6 @@ import PIL.Image
 import torch
 import torch.nn.functional as F
 from easydict import EasyDict as edict
-from frozen_clip_embedder_t3 import FrozenCLIPEmbedderT3
 from huggingface_hub import hf_hub_download
 from ocr_recog.RecModel import RecModel
 from PIL import Image, ImageDraw, ImageFont
@@ -520,6 +519,222 @@ class TextRecognizer(object):
         return loss


+import torch
+from torch import nn
+from transformers import CLIPTextModel, CLIPTokenizer
+from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
+
+
+class AbstractEncoder(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def encode(self, *args, **kwargs):
+        raise NotImplementedError
+
+
+class FrozenCLIPEmbedderT3(AbstractEncoder):
+    """Uses the CLIP transformer encoder for text (from Hugging Face)"""
+
+    def __init__(
+        self,
+        version="openai/clip-vit-large-patch14",
+        device="cpu",
+        max_length=77,
+        freeze=True,
+        use_fp16=False,
+    ):
+        super().__init__()
+        self.tokenizer = CLIPTokenizer.from_pretrained(version)
+        self.transformer = CLIPTextModel.from_pretrained(
+            version, use_safetensors=True, torch_dtype=torch.float16 if use_fp16 else torch.float32
+        ).to(device)
+        self.device = device
+        self.max_length = max_length
+        if freeze:
+            self.freeze()
+
+        def embedding_forward(
+            self,
+            input_ids=None,
+            position_ids=None,
+            inputs_embeds=None,
+            embedding_manager=None,
+        ):
+            seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
+            if position_ids is None:
+                position_ids = self.position_ids[:, :seq_length]
+            if inputs_embeds is None:
+                inputs_embeds = self.token_embedding(input_ids)
+            if embedding_manager is not None:
+                inputs_embeds = embedding_manager(input_ids, inputs_embeds)
+            position_embeddings = self.position_embedding(position_ids)
+            embeddings = inputs_embeds + position_embeddings
+            return embeddings
+
+        self.transformer.text_model.embeddings.forward = embedding_forward.__get__(
+            self.transformer.text_model.embeddings
+        )
+
+        def encoder_forward(
+            self,
+            inputs_embeds,
+            attention_mask=None,
+            causal_attention_mask=None,
+            output_attentions=None,
+            output_hidden_states=None,
+            return_dict=None,
+        ):
+            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+            output_hidden_states = (
+                output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+            )
+            return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+            encoder_states = () if output_hidden_states else None
+            all_attentions = () if output_attentions else None
+            hidden_states = inputs_embeds
+            for idx, encoder_layer in enumerate(self.layers):
+                if output_hidden_states:
+                    encoder_states = encoder_states + (hidden_states,)
+                layer_outputs = encoder_layer(
+                    hidden_states,
+                    attention_mask,
+                    causal_attention_mask,
+                    output_attentions=output_attentions,
+                )
+                hidden_states = layer_outputs[0]
+                if output_attentions:
+                    all_attentions = all_attentions + (layer_outputs[1],)
+            if output_hidden_states:
+                encoder_states = encoder_states + (hidden_states,)
+            return hidden_states
+
+        self.transformer.text_model.encoder.forward = encoder_forward.__get__(self.transformer.text_model.encoder)
+
+        def text_encoder_forward(
+            self,
+            input_ids=None,
+            attention_mask=None,
+            position_ids=None,
+            output_attentions=None,
+            output_hidden_states=None,
+            return_dict=None,
+            embedding_manager=None,
+        ):
+            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+            output_hidden_states = (
+                output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+            )
+            return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+            if input_ids is None:
+                raise ValueError("You have to specify either input_ids")
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+            hidden_states = self.embeddings(
+                input_ids=input_ids, position_ids=position_ids, embedding_manager=embedding_manager
+            )
+            # CLIP's text model uses causal mask, prepare it here.
+            # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
+            causal_attention_mask = _create_4d_causal_attention_mask(
+                input_shape, hidden_states.dtype, device=hidden_states.device
+            )
+            # expand attention_mask
+            if attention_mask is not None:
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
+            last_hidden_state = self.encoder(
+                inputs_embeds=hidden_states,
+                attention_mask=attention_mask,
+                causal_attention_mask=causal_attention_mask,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+            last_hidden_state = self.final_layer_norm(last_hidden_state)
+            return last_hidden_state
+
+        self.transformer.text_model.forward = text_encoder_forward.__get__(self.transformer.text_model)
+
+        def transformer_forward(
+            self,
+            input_ids=None,
+            attention_mask=None,
+            position_ids=None,
+            output_attentions=None,
+            output_hidden_states=None,
+            return_dict=None,
+            embedding_manager=None,
+        ):
+            return self.text_model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+                embedding_manager=embedding_manager,
+            )
+
+        self.transformer.forward = transformer_forward.__get__(self.transformer)
+
+    def freeze(self):
+        self.transformer = self.transformer.eval()
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def forward(self, text, **kwargs):
+        batch_encoding = self.tokenizer(
+            text,
+            truncation=False,
+            max_length=self.max_length,
+            return_length=True,
+            return_overflowing_tokens=False,
+            padding="longest",
+            return_tensors="pt",
+        )
+        input_ids = batch_encoding["input_ids"]
+        tokens_list = self.split_chunks(input_ids)
+        z_list = []
+        for tokens in tokens_list:
+            tokens = tokens.to(self.device)
+            _z = self.transformer(input_ids=tokens, **kwargs)
+            z_list += [_z]
+        return torch.cat(z_list, dim=1)
+
+    def encode(self, text, **kwargs):
+        return self(text, **kwargs)
+
+    def split_chunks(self, input_ids, chunk_size=75):
+        tokens_list = []
+        bs, n = input_ids.shape
+        id_start = input_ids[:, 0].unsqueeze(1)  # dim --> [bs, 1]
+        id_end = input_ids[:, -1].unsqueeze(1)
+        if n == 2:  # empty caption
+            tokens_list.append(torch.cat((id_start,) + (id_end,) * (chunk_size + 1), dim=1))
+
+        trimmed_encoding = input_ids[:, 1:-1]
+        num_full_groups = (n - 2) // chunk_size
+
+        for i in range(num_full_groups):
+            group = trimmed_encoding[:, i * chunk_size : (i + 1) * chunk_size]
+            group_pad = torch.cat((id_start, group, id_end), dim=1)
+            tokens_list.append(group_pad)
+
+        remaining_columns = (n - 2) % chunk_size
+        if remaining_columns > 0:
+            remaining_group = trimmed_encoding[:, -remaining_columns:]
+            padding_columns = chunk_size - remaining_group.shape[1]
+            padding = id_end.expand(bs, padding_columns)
+            remaining_group_pad = torch.cat((id_start, remaining_group, padding, id_end), dim=1)
+            tokens_list.append(remaining_group_pad)
+        return tokens_list
+
+    def to(self, *args, **kwargs):
+        self.transformer = self.transformer.to(*args, **kwargs)
+        self.device = self.transformer.device
+        return self
+
+
 class TextEmbeddingModule(nn.Module):
     def __init__(self, font_path, use_fp16=False, device="cpu"):
         super().__init__()
@@ -1012,7 +1227,7 @@ class AnyTextPipeline(
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
-        text_encoder ([`~transformers.CLIPTextModel`]):
+        text_encoder ([`~anytext.TextEmbeddingModule`]):
            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
         tokenizer ([`~transformers.CLIPTokenizer`]):
            A `CLIPTokenizer` to tokenize text.
@@ -1042,26 +1257,25 @@ class AnyTextPipeline(
         self,
         font_path: str,
         vae: AutoencoderKL,
-        text_encoder: CLIPTextModel,
+        text_encoder: TextEmbeddingModule,
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
         controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
         scheduler: KarrasDiffusionSchedulers,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPImageProcessor,
+        auxiliary_latent_module: AuxiliaryLatentModule,
         trust_remote_code: bool = False,
-        text_embedding_module: TextEmbeddingModule = None,
-        auxiliary_latent_module: AuxiliaryLatentModule = None,
         image_encoder: CLIPVisionModelWithProjection = None,
         requires_safety_checker: bool = True,
     ):
         super().__init__()
-        self.text_embedding_module = TextEmbeddingModule(
-            use_fp16=unet.dtype == torch.float16, device=unet.device, font_path=font_path
-        )
-        self.auxiliary_latent_module = AuxiliaryLatentModule(
-            vae=vae, use_fp16=unet.dtype == torch.float16, device=unet.device, font_path=font_path
-        )
+        # self.text_embedding_module = TextEmbeddingModule(
+        #     use_fp16=unet.dtype == torch.float16, device=unet.device, font_path=font_path
+        # )
+        # self.auxiliary_latent_module = AuxiliaryLatentModule(
+        #     vae=vae, use_fp16=unet.dtype == torch.float16, device=unet.device, font_path=font_path
+        # )

         if safety_checker is None and requires_safety_checker:
             logger.warning(
@@ -1092,8 +1306,8 @@ class AnyTextPipeline(
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
             image_encoder=image_encoder,
-            text_embedding_module=self.text_embedding_module,
-            auxiliary_latent_module=self.auxiliary_latent_module,
+            # text_embedding_module=self.text_embedding_module,
+            auxiliary_latent_module=auxiliary_latent_module,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
@@ -1961,7 +2175,7 @@ class AnyTextPipeline(
             self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
         )
         draw_pos = draw_pos.to(device=device) if isinstance(draw_pos, torch.Tensor) else draw_pos
-        prompt_embeds, negative_prompt_embeds, text_info, np_hint = self.text_embedding_module(
+        prompt_embeds, negative_prompt_embeds, text_info, np_hint = self.text_encoder(
             prompt,
             texts,
             negative_prompt,
@@ -2203,6 +2417,6 @@ class AnyTextPipeline(

     def to(self, *args, **kwargs):
         super().to(*args, **kwargs)
-        self.text_embedding_module.to(*args, **kwargs)
+        # self.text_embedding_module.to(*args, **kwargs)
         self.auxiliary_latent_module.to(*args, **kwargs)
         return self
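
Note on the constructor change in this commit: `AnyTextPipeline.__init__` no longer builds a `TextEmbeddingModule` and `AuxiliaryLatentModule` internally; the caller now passes a `TextEmbeddingModule` as `text_encoder` and an `AuxiliaryLatentModule` as `auxiliary_latent_module`. A minimal wiring sketch under that signature; the base checkpoint, ControlNet repo, and font path below are placeholders, not values taken from this commit:

```python
# Sketch only: names marked "placeholder" are assumptions, not part of this commit.
import torch
from diffusers import AutoencoderKL, ControlNetModel, DDIMScheduler, UNet2DConditionModel
from transformers import CLIPTokenizer

from anytext import AnyTextPipeline, AuxiliaryLatentModule, TextEmbeddingModule

font_path = "path/to/a/font.ttf"                        # placeholder
base = "runwayml/stable-diffusion-v1-5"                 # placeholder base checkpoint
controlnet_repo = "path/or/repo/of/anytext-controlnet"  # placeholder

vae = AutoencoderKL.from_pretrained(base, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(base, subfolder="unet")
tokenizer = CLIPTokenizer.from_pretrained(base, subfolder="tokenizer")
scheduler = DDIMScheduler.from_pretrained(base, subfolder="scheduler")
controlnet = ControlNetModel.from_pretrained(controlnet_repo)

# These two modules are now constructed by the caller and handed to the pipeline.
text_encoder = TextEmbeddingModule(font_path=font_path, use_fp16=unet.dtype == torch.float16)
auxiliary_latent_module = AuxiliaryLatentModule(vae=vae, font_path=font_path)

pipe = AnyTextPipeline(
    font_path=font_path,
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    controlnet=controlnet,
    scheduler=scheduler,
    safety_checker=None,
    feature_extractor=None,
    auxiliary_latent_module=auxiliary_latent_module,
    requires_safety_checker=False,
)
```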
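
The re-added `FrozenCLIPEmbedderT3` handles prompts longer than CLIP's 77-token window by tokenizing with `padding="longest"` and no truncation, cutting the ids into 75-token groups re-wrapped with the start/end tokens (`split_chunks`), encoding each group, and concatenating the per-group embeddings along the sequence axis. A standalone illustration of just the chunking arithmetic (the token ids are made up; only the shapes matter):

```python
# Mirrors the split_chunks logic from FrozenCLIPEmbedderT3 in the diff above,
# pulled out of the class to show how a long prompt becomes several 77-token inputs.
import torch


def split_chunks(input_ids: torch.Tensor, chunk_size: int = 75):
    tokens_list = []
    bs, n = input_ids.shape
    id_start = input_ids[:, 0].unsqueeze(1)  # start-token column, shape [bs, 1]
    id_end = input_ids[:, -1].unsqueeze(1)   # end-token column, shape [bs, 1]
    if n == 2:  # empty caption: emit one chunk padded entirely with the end token
        tokens_list.append(torch.cat((id_start,) + (id_end,) * (chunk_size + 1), dim=1))
    trimmed = input_ids[:, 1:-1]
    for i in range((n - 2) // chunk_size):  # full 75-token groups
        group = trimmed[:, i * chunk_size : (i + 1) * chunk_size]
        tokens_list.append(torch.cat((id_start, group, id_end), dim=1))
    remaining = (n - 2) % chunk_size
    if remaining > 0:  # last partial group, padded out with the end token
        tail = trimmed[:, -remaining:]
        pad = id_end.expand(bs, chunk_size - remaining)
        tokens_list.append(torch.cat((id_start, tail, pad, id_end), dim=1))
    return tokens_list


# 1 start token + 100 prompt tokens + 1 end token -> two 77-token chunks
ids = torch.cat([torch.tensor([[49406]]), torch.arange(100).unsqueeze(0), torch.tensor([[49407]])], dim=1)
print([chunk.shape for chunk in split_chunks(ids)])  # [torch.Size([1, 77]), torch.Size([1, 77])]
```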