tolgacangoz committed (verified) · Commit c75241d · 1 Parent(s): 24482b9

Upload anytext.py

Files changed (1)
  1. anytext.py +47 -77
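
For quick orientation, a minimal usage sketch of the pipeline as of this commit, assembled from the updated EXAMPLE_DOC_STRING in the diff below. The AnyTextControlNetModel import path is an assumption (it ships alongside anytext.py in this repo), and note that the committed example sets trust_remote_code=False so that the user opts in to running the repo's custom code explicitly; the sketch opts in:

import torch
from diffusers import DiffusionPipeline, DDIMScheduler

# Assumed import path: AnyTextControlNetModel lives next to anytext.py in this community repo.
from anytext_controlnet import AnyTextControlNetModel

anytext_controlnet = AnyTextControlNetModel.from_pretrained(
    "tolgacangoz/anytext-controlnet", torch_dtype=torch.float16, variant="fp16",
)
# font_path is now mandatory; per this commit, the constructor raises ValueError when it is None.
pipe = DiffusionPipeline.from_pretrained(
    "tolgacangoz/anytext",
    font_path="arial-unicode-ms.ttf",
    controlnet=anytext_controlnet,
    torch_dtype=torch.float16,
    trust_remote_code=True,  # explicit permission to execute this repo's custom pipeline code
).to("cuda")
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)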
anytext.py CHANGED
@@ -41,8 +41,10 @@ from safetensors.torch import load_file
 from skimage.transform._geometric import _umeyama as get_sym_mat
 from torch import nn
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
 
 from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
+from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
 from diffusers.loaders import (
     FromSingleFileMixin,
@@ -52,13 +54,12 @@ from diffusers.loaders import (
 )
 from diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
 from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.models.modeling_utils import ModelMixin
 from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from diffusers.schedulers import KarrasDiffusionSchedulers
-from diffusers.configuration_utils import register_to_config, ConfigMixin
-from diffusers.models.modeling_utils import ModelMixin
 from diffusers.utils import (
     USE_PEFT_BACKEND,
     deprecate,
@@ -154,21 +155,14 @@ EXAMPLE_DOC_STRING = """
 >>> # I chose a font file shared by an HF staff:
 >>> !wget https://huggingface.co/spaces/ysharma/TranslateQuotesInImageForwards/resolve/main/arial-unicode-ms.ttf
 
->>> # load control net and stable diffusion v1-5
 >>> anytext_controlnet = AnyTextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16,
 ...                                                             variant="fp16",)
 >>> pipe = DiffusionPipeline.from_pretrained("tolgacangoz/anytext", font_path="arial-unicode-ms.ttf",
 ...                                          controlnet=anytext_controlnet, torch_dtype=torch.float16,
-...                                          trust_remote_code=True,
+...                                          trust_remote_code=False,  # One needs to give permission to run this pipeline's code
 ...                                          ).to("cuda")
 
 >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
->>> # uncomment following line if PyTorch>=2.0 is not installed for memory optimization
->>> #pipe.enable_xformers_memory_efficient_attention()
-
->>> # uncomment following line if you want to offload the model to CPU for memory optimization
->>> # also remove the `.to("cuda")` part
->>> #pipe.enable_model_cpu_offload()
 
 >>> # generate image
 >>> prompt = 'photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream'
@@ -211,8 +205,8 @@ class EmbeddingManager(ModelMixin, ConfigMixin):
         embedder,
         placeholder_string="*",
         use_fp16=False,
-        token_dim = 768,
-        get_recog_emb = None,
+        token_dim=768,
+        get_recog_emb=None,
     ):
         super().__init__()
         get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer)
@@ -227,9 +221,7 @@ class EmbeddingManager(ModelMixin, ConfigMixin):
         if use_fp16:
             self.proj = self.proj.to(dtype=torch.float16)
 
-        # self.register_parameter("proj", proj)
         self.placeholder_token = get_token_for_string(placeholder_string)
-        # self.register_config(placeholder_token=placeholder_token)
 
     @torch.no_grad()
     def encode_text(self, text_info):
@@ -350,12 +342,19 @@ def create_predictor(model_lang="ch", device="cpu", use_fp16=False):
         n_class = 97
     else:
         raise ValueError(f"Unsupported OCR recog model_lang: {model_lang}")
-    rec_config = dict(
-        in_channels=3,
-        backbone=dict(type="MobileNetV1Enhance", scale=0.5, last_conv_stride=[1, 2], last_pool_type="avg"),
-        neck=dict(type="SequenceEncoder", encoder_type="svtr", dims=64, depth=2, hidden_dims=120, use_guide=True),
-        head=dict(type="CTCHead", fc_decay=0.00001, out_channels=n_class, return_feats=True),
-    )
+    rec_config = {
+        "in_channels": 3,
+        "backbone": {"type": "MobileNetV1Enhance", "scale": 0.5, "last_conv_stride": [1, 2], "last_pool_type": "avg"},
+        "neck": {
+            "type": "SequenceEncoder",
+            "encoder_type": "svtr",
+            "dims": 64,
+            "depth": 2,
+            "hidden_dims": 120,
+            "use_guide": True,
+        },
+        "head": {"type": "CTCHead", "fc_decay": 0.00001, "out_channels": n_class, "return_feats": True},
+    }
 
     rec_model = RecModel(rec_config)
     state_dict = torch.load(model_dir, map_location=device)
@@ -521,12 +520,6 @@ class TextRecognizer(object):
         return loss
 
 
-import torch
-from torch import nn
-from transformers import CLIPTextModel, CLIPTokenizer
-from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
-
-
 class AbstractEncoder(nn.Module):
     def __init__(self):
         super().__init__()
@@ -537,6 +530,7 @@ class AbstractEncoder(nn.Module):
 
 class FrozenCLIPEmbedderT3(AbstractEncoder, ModelMixin, ConfigMixin):
     """Uses the CLIP transformer encoder for text (from Hugging Face)"""
+
     @register_to_config
     def __init__(
         self,
@@ -548,11 +542,13 @@ class FrozenCLIPEmbedderT3(AbstractEncoder, ModelMixin, ConfigMixin):
     ):
         super().__init__()
         self.tokenizer = CLIPTokenizer.from_pretrained("tolgacangoz/anytext", subfolder="tokenizer")
-        self.transformer = CLIPTextModel.from_pretrained("tolgacangoz/anytext", subfolder="text_encoder",
-                                                         torch_dtype=torch.float16 if use_fp16 else torch.float32,
-                                                         variant="fp16" if use_fp16 else None)
-        # self.device = device
-        # self.max_length = max_length
+        self.transformer = CLIPTextModel.from_pretrained(
+            "tolgacangoz/anytext",
+            subfolder="text_encoder",
+            torch_dtype=torch.float16 if use_fp16 else torch.float32,
+            variant="fp16" if use_fp16 else None,
+        )
+
         if freeze:
             self.freeze()
 
@@ -731,11 +727,6 @@ class FrozenCLIPEmbedderT3(AbstractEncoder, ModelMixin, ConfigMixin):
             tokens_list.append(remaining_group_pad)
         return tokens_list
 
-    # def to(self, *args, **kwargs):
-    #     self.transformer = self.transformer.to(*args, **kwargs)
-    #     self.device = self.transformer.device
-    #     return self
-
 
 class TextEmbeddingModule(ModelMixin, ConfigMixin):
     @register_to_config
@@ -743,25 +734,21 @@ class TextEmbeddingModule(ModelMixin, ConfigMixin):
         super().__init__()
         font = ImageFont.truetype(font_path, 60)
 
-        # self.use_fp16 = use_fp16
-        # self.device = device
         self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=device, use_fp16=use_fp16)
         self.embedding_manager = EmbeddingManager(self.frozen_CLIP_embedder_t3, use_fp16=use_fp16)
         self.text_predictor = create_predictor(device=device, use_fp16=use_fp16).eval()
-        args = {"rec_image_shape": "3, 48, 320",
-                "rec_batch_num": 6,
-                "rec_char_dict_path": hf_hub_download(
-                    repo_id="tolgacangoz/anytext",
-                    filename="text_embedding_module/OCR/ppocr_keys_v1.txt",
-                    cache_dir=HF_MODULES_CACHE,
-                ),
-                "use_fp16": use_fp16}
+        args = {
+            "rec_image_shape": "3, 48, 320",
+            "rec_batch_num": 6,
+            "rec_char_dict_path": hf_hub_download(
+                repo_id="tolgacangoz/anytext",
+                filename="text_embedding_module/OCR/ppocr_keys_v1.txt",
+                cache_dir=HF_MODULES_CACHE,
+            ),
+            "use_fp16": use_fp16,
+        }
         self.embedding_manager.recog = TextRecognizer(args, self.text_predictor)
 
-        # self.register_modules(
-        #     frozen_CLIP_embedder_t3=frozen_CLIP_embedder_t3,
-        #     embedding_manager=embedding_manager,
-        # )
         self.register_to_config(font=font)
 
     @torch.no_grad()
@@ -873,8 +860,6 @@ class TextEmbeddingModule(ModelMixin, ConfigMixin):
             text_info["gly_line"] += [self.arr2tensor(gly_line, num_images_per_prompt)]
             text_info["positions"] += [self.arr2tensor(pos, num_images_per_prompt)]
 
-        # hint = self.arr2tensor(np_hint, len(prompt))
-
         self.embedding_manager.encode_text(text_info)
         prompt_embeds = self.frozen_CLIP_embedder_t3.encode([prompt], embedding_manager=self.embedding_manager)
 
@@ -1028,11 +1013,6 @@ class TextEmbeddingModule(ModelMixin, ConfigMixin):
             new_string += char + " " * nSpace
         return new_string[:-nSpace]
 
-    # def to(self, *args, **kwargs):
-    #     self.frozen_CLIP_embedder_t3 = self.frozen_CLIP_embedder_t3.to(*args, **kwargs)
-    #     self.embedding_manager = self.embedding_manager.to(*args, **kwargs)
-    #     return self
-
 
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
 def retrieve_latents(
@@ -1052,13 +1032,10 @@ class AuxiliaryLatentModule(ModelMixin, ConfigMixin):
     @register_to_config
     def __init__(
         self,
-        # font_path,
         vae,
         device="cpu",
     ):
         super().__init__()
-        # self.font = ImageFont.truetype(font_path, 60)
-        # self.vae = vae.eval() if vae is not None else None
 
     @torch.no_grad()
     def forward(
@@ -1100,7 +1077,9 @@ class AuxiliaryLatentModule(ModelMixin, ConfigMixin):
         masked_img = torch.from_numpy(masked_img.copy()).float().to(device)
         if dtype == torch.float16:
             masked_img = masked_img.half()
-        masked_x = (retrieve_latents(self.config.vae.encode(masked_img[None, ...])) * self.config.vae.config.scaling_factor).detach()
+        masked_x = (
+            retrieve_latents(self.config.vae.encode(masked_img[None, ...])) * self.config.vae.config.scaling_factor
+        ).detach()
         if dtype == torch.float16:
             masked_x = masked_x.half()
         text_info["masked_x"] = torch.cat([masked_x for _ in range(num_images_per_prompt)], dim=0)
@@ -1140,11 +1119,6 @@ class AuxiliaryLatentModule(ModelMixin, ConfigMixin):
             new_string += char + " " * nSpace
         return new_string[:-nSpace]
 
-    # def to(self, *args, **kwargs):
-    #     self.vae = self.vae.to(*args, **kwargs)
-    #     self.device = self.vae.device
-    #     return self
-
 
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
@@ -1266,7 +1240,7 @@ class AnyTextPipeline(
         scheduler: KarrasDiffusionSchedulers,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPImageProcessor,
-        font_path: str = "arial-unicode-ms.ttf",
+        font_path: str = None,
        text_embedding_module: Optional[TextEmbeddingModule] = None,
         auxiliary_latent_module: Optional[AuxiliaryLatentModule] = None,
         trust_remote_code: bool = False,
@@ -1274,15 +1248,11 @@ class AnyTextPipeline(
         requires_safety_checker: bool = True,
     ):
         super().__init__()
-        text_embedding_module = TextEmbeddingModule(
-            font_path=font_path,
-            use_fp16=unet.dtype == torch.float16,
-        )
-        auxiliary_latent_module = AuxiliaryLatentModule(
-            # font_path=font_path,
-            vae=vae,
-            # use_fp16=unet.dtype == torch.float16,
-        )
+        if font_path is None:
+            raise ValueError("font_path is required!")
+
+        text_embedding_module = TextEmbeddingModule(font_path=font_path, use_fp16=unet.dtype == torch.float16)
+        auxiliary_latent_module = AuxiliaryLatentModule(vae=vae)
 
         if safety_checker is None and requires_safety_checker:
             logger.warning(
@@ -1321,7 +1291,7 @@ class AnyTextPipeline(
         self.control_image_processor = VaeImageProcessor(
             vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
         )
-        self.register_to_config(requires_safety_checker=requires_safety_checker)#, font_path=font_path)
+        self.register_to_config(requires_safety_checker=requires_safety_checker)
 
     def modify_prompt(self, prompt):
         prompt = prompt.replace("“", '"')
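
A sanity note on the rec_config hunk: swapping dict(...) construction for a plain literal is behavior-preserving; both forms build the same nested mapping. A minimal self-contained check, using two of the keys from the diff:

# Both construction styles yield equal dictionaries; only the syntax differs.
old = dict(
    in_channels=3,
    backbone=dict(type="MobileNetV1Enhance", scale=0.5, last_conv_stride=[1, 2], last_pool_type="avg"),
)
new = {
    "in_channels": 3,
    "backbone": {"type": "MobileNetV1Enhance", "scale": 0.5, "last_conv_stride": [1, 2], "last_pool_type": "avg"},
}
assert old == new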