Text-to-Image · Diffusers · Safetensors

tolgacangoz committed · Commit a4686b9 · verified · 1 Parent(s): 225e8aa

Upload anytext.py

Files changed (1):
  auxiliary_latent_module/anytext.py (+131 −34)

auxiliary_latent_module/anytext.py CHANGED
@@ -25,6 +25,7 @@ import math
 import os
 import re
 import sys
+import unicodedata
 from functools import partial
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -33,9 +34,9 @@ import numpy as np
 import PIL.Image
 import torch
 import torch.nn.functional as F
-from bert_tokenizer import BasicTokenizer
 from easydict import EasyDict as edict
 from frozen_clip_embedder_t3 import FrozenCLIPEmbedderT3
+from huggingface_hub import hf_hub_download
 from ocr_recog.RecModel import RecModel
 from PIL import Image, ImageDraw, ImageFont
 from safetensors.torch import load_file
@@ -66,12 +67,75 @@ from diffusers.utils import (
     scale_lora_layers,
     unscale_lora_layers,
 )
+from diffusers.utils.constants import HF_MODULES_CACHE
 from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
-from diffusers.configuration_utils import register_to_config, ConfigMixin
-from diffusers.models.modeling_utils import ModelMixin


-checker = BasicTokenizer()
+class Checker:
+    def __init__(self):
+        pass
+
+    def _is_chinese_char(self, cp):
+        """Checks whether `cp` is the codepoint of a CJK character."""
+        # This defines a "chinese character" as anything in the CJK Unicode block:
+        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+        #
+        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
+        # despite its name. The modern Korean Hangul alphabet is a different block,
+        # as are Japanese Hiragana and Katakana. Those alphabets are used to write
+        # space-separated words, so they are not treated specially and are handled
+        # like all of the other languages.
+        if (
+            (cp >= 0x4E00 and cp <= 0x9FFF)
+            or (cp >= 0x3400 and cp <= 0x4DBF)
+            or (cp >= 0x20000 and cp <= 0x2A6DF)
+            or (cp >= 0x2A700 and cp <= 0x2B73F)
+            or (cp >= 0x2B740 and cp <= 0x2B81F)
+            or (cp >= 0x2B820 and cp <= 0x2CEAF)
+            or (cp >= 0xF900 and cp <= 0xFAFF)
+            or (cp >= 0x2F800 and cp <= 0x2FA1F)
+        ):
+            return True
+
+        return False
+
+    def _clean_text(self, text):
+        """Performs invalid character removal and whitespace cleanup on text."""
+        output = []
+        for char in text:
+            cp = ord(char)
+            if cp == 0 or cp == 0xFFFD or self._is_control(char):
+                continue
+            if self._is_whitespace(char):
+                output.append(" ")
+            else:
+                output.append(char)
+        return "".join(output)
+
+    def _is_control(self, char):
+        """Checks whether `char` is a control character."""
+        # These are technically control characters, but we count them as whitespace
+        # characters.
+        if char == "\t" or char == "\n" or char == "\r":
+            return False
+        cat = unicodedata.category(char)
+        if cat in ("Cc", "Cf"):
+            return True
+        return False
+
+    def _is_whitespace(self, char):
+        """Checks whether `char` is a whitespace character."""
+        # \t, \n, and \r are technically control characters, but we treat them
+        # as whitespace since they are generally considered as such.
+        if char == " " or char == "\t" or char == "\n" or char == "\r":
+            return True
+        cat = unicodedata.category(char)
+        if cat == "Zs":
+            return True
+        return False
+
+
+checker = Checker()


 PLACE_HOLDER = "*"
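The inlined `Checker` drops the `bert_tokenizer` dependency while keeping only the helpers AnyText actually uses. A standalone sketch of the CJK-range check for sanity-testing the logic (the names below are illustrative, not from the commit):

```py
import unicodedata

# Standalone sketch mirroring Checker._is_chinese_char / _is_whitespace.
CJK_RANGES = [
    (0x4E00, 0x9FFF), (0x3400, 0x4DBF), (0x20000, 0x2A6DF),
    (0x2A700, 0x2B73F), (0x2B740, 0x2B81F), (0x2B820, 0x2CEAF),
    (0xF900, 0xFAFF), (0x2F800, 0x2FA1F),
]

def is_chinese_char(cp: int) -> bool:
    return any(lo <= cp <= hi for lo, hi in CJK_RANGES)

print(is_chinese_char(ord("文")))      # True: CJK Unified Ideograph
print(is_chinese_char(ord("A")))       # False: Latin
print(unicodedata.category("\u00A0"))  # 'Zs', so _is_whitespace treats NBSP as whitespace
```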
@@ -81,18 +145,22 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
-        >>> from pipeline_anytext import AnyTextPipeline
+        >>> import torch
+        >>> from diffusers import DiffusionPipeline
         >>> from anytext_controlnet import AnyTextControlNetModel
         >>> from diffusers import DDIMScheduler
        >>> from diffusers.utils import load_image
-        >>> import torch
+
+        >>> # A font file shared by an HF staff member:
+        >>> !wget https://huggingface.co/spaces/ysharma/TranslateQuotesInImageForwards/resolve/main/arial-unicode-ms.ttf

         >>> # load control net and stable diffusion v1-5
-        >>> text_controlnet = AnyTextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16,
-        ...                                                          variant="fp16",)
-        >>> pipe = AnyTextPipeline.from_pretrained("tolgacangoz/anytext", controlnet=text_controlnet,
-        ...                                        torch_dtype=torch.float16, variant="fp16",
-        ...                                        ).to("cuda")
+        >>> anytext_controlnet = AnyTextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16,
+        ...                                                             variant="fp16",)
+        >>> pipe = DiffusionPipeline.from_pretrained("tolgacangoz/anytext", font_path="arial-unicode-ms.ttf",
+        ...                                          controlnet=anytext_controlnet, torch_dtype=torch.float16,
+        ...                                          trust_remote_code=True,
+        ...                                          ).to("cuda")

         >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
         >>> # uncomment the following line if PyTorch>=2.0 is not installed, for memory optimization
@@ -103,11 +171,9 @@ EXAMPLE_DOC_STRING = """
         >>> #pipe.enable_model_cpu_offload()

         >>> # generate image
-        >>> generator = torch.Generator("cpu").manual_seed(66273235)
         >>> prompt = 'photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream'
-        >>> draw_pos = load_image("www.huggingface.co/a/AnyText/tree/main/examples/gen9.png")
-        >>> image = pipe(prompt, num_inference_steps=20, generator=generator, mode="generate",
-        ...              draw_pos=draw_pos,
+        >>> draw_pos = load_image("https://raw.githubusercontent.com/tyxsspa/AnyText/refs/heads/main/example_images/gen9.png")
+        >>> image = pipe(prompt, num_inference_steps=20, mode="generate", draw_pos=draw_pos,
         ...              ).images[0]
         >>> image
         ```
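The updated example no longer seeds a generator; for reproducible outputs one can still be passed through, as the removed line did:

```py
import torch

# Optional reproducibility sketch; the seed is the one the old example hard-coded.
generator = torch.Generator("cpu").manual_seed(66273235)
# image = pipe(prompt, num_inference_steps=20, generator=generator,
#              mode="generate", draw_pos=draw_pos).images[0]
```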
@@ -152,7 +218,12 @@ class EmbeddingManager(nn.Module):
         self.token_dim = token_dim

         self.proj = nn.Linear(40 * 64, token_dim)
-        # self.proj.load_state_dict(load_file("proj.safetensors", device=str(embedder.device)))
+        proj_dir = hf_hub_download(
+            repo_id="tolgacangoz/anytext",
+            filename="text_embedding_module/proj.safetensors",
+            cache_dir=HF_MODULES_CACHE,
+        )
+        self.proj.load_state_dict(load_file(proj_dir, device=str(embedder.device)))
         if use_fp16:
             self.proj = self.proj.to(dtype=torch.float16)
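The previously commented-out local load becomes a real Hub download. The same fetch-then-load pattern as a self-contained sketch; `768` stands in for `token_dim` and is an assumption:

```py
import torch.nn as nn
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# Download once into the local HF cache (repeat calls reuse the cached file),
# then load the weights; repo_id/filename are copied from the hunk above.
proj = nn.Linear(40 * 64, 768)  # 768 = assumed token_dim
path = hf_hub_download(
    repo_id="tolgacangoz/anytext",
    filename="text_embedding_module/proj.safetensors",
)
proj.load_state_dict(load_file(path, device="cpu"))
```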
@@ -269,9 +340,14 @@ def crop_image(src_img, mask):


 def create_predictor(model_dir=None, model_lang="ch", device="cpu", use_fp16=False):
-    model_file_path = model_dir
-    if model_file_path is not None and not os.path.exists(model_file_path):
-        raise ValueError("not find model file path {}".format(model_file_path))
+    if model_dir is None or not os.path.exists(model_dir):
+        model_dir = hf_hub_download(
+            repo_id="tolgacangoz/anytext",
+            filename="text_embedding_module/OCR/ppv3_rec.pth",
+            cache_dir=HF_MODULES_CACHE,
+        )
+    if not os.path.exists(model_dir):
+        raise ValueError("Could not find model file at {}".format(model_dir))

     if model_lang == "ch":
         n_class = 6625
@@ -287,8 +363,8 @@ def create_predictor(model_dir=None, model_lang="ch", device="cpu", use_fp16=Fal
     )

     rec_model = RecModel(rec_config)
-    if model_file_path is not None:
-        rec_model.load_state_dict(torch.load(model_file_path, map_location=device))
+    state_dict = torch.load(model_dir, map_location=device)
+    rec_model.load_state_dict(state_dict)
     return rec_model

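With the fallback above, `create_predictor` no longer needs a local checkpoint; a `None` or missing `model_dir` resolves to the Hub file before loading. A hedged usage sketch:

```py
# Usage sketch: with model_dir=None the OCR weights are fetched from
# tolgacangoz/anytext (text_embedding_module/OCR/ppv3_rec.pth) and loaded.
predictor = create_predictor(model_dir=None, model_lang="ch", device="cpu")
predictor = predictor.eval()  # inference mode, as TextEmbeddingModule does
```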
@@ -401,7 +477,7 @@ class TextRecognizer(object):
             preds["ctc"] = torch.from_numpy(outputs[0])
             preds["ctc_neck"] = [torch.zeros(1)] * img_num
         else:
-            preds = self.predictor(norm_img_batch)
+            preds = self.predictor(norm_img_batch.to(next(self.predictor.parameters()).device))
         for rno in range(preds["ctc"].shape[0]):
             preds_all[indices[beg_img_no + rno]] = preds["ctc"][rno]
             preds_neck_all[indices[beg_img_no + rno]] = preds["ctc_neck"][rno]
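The fix routes the batch to whatever device the recognizer's weights already occupy, via the common `next(model.parameters()).device` idiom. A minimal, self-contained illustration:

```py
import torch
import torch.nn as nn

# Move the input to the device of the model's first parameter before calling it.
model = nn.Linear(4, 2)
batch = torch.randn(3, 4)
out = model(batch.to(next(model.parameters()).device))
print(out.shape)  # torch.Size([3, 2])
```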
@@ -450,21 +526,28 @@ class TextRecognizer(object):
         return loss


-class TextEmbeddingModule(ModelMixin, ConfigMixin):
-    @register_to_config
+class TextEmbeddingModule(nn.Module):
     def __init__(self, font_path, use_fp16=False, device="cpu"):
         super().__init__()
         # TODO: Learn if the recommended font file is free to use
         self.font = ImageFont.truetype(font_path, 60)
+        self.use_fp16 = use_fp16
+        self.device = device
         self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=device, use_fp16=use_fp16)
         self.embedding_manager = EmbeddingManager(self.frozen_CLIP_embedder_t3, use_fp16=use_fp16)
-        rec_model_dir = "./OCR/ppv3_rec.pth"
+        rec_model_dir = "./text_embedding_module/OCR/ppv3_rec.pth"
         self.text_predictor = create_predictor(rec_model_dir, device=device, use_fp16=use_fp16).eval()
         args = {}
         args["rec_image_shape"] = "3, 48, 320"
         args["rec_batch_num"] = 6
-        args["rec_char_dict_path"] = "OCR/ppocr_keys_v1.txt"
-        args["use_fp16"] = self.use_fp16
+        args["rec_char_dict_path"] = hf_hub_download(
+            repo_id="tolgacangoz/anytext",
+            filename="text_embedding_module/OCR/ppocr_keys_v1.txt",
+            cache_dir=HF_MODULES_CACHE,
+        )
+        args["use_fp16"] = use_fp16
         self.embedding_manager.recog = TextRecognizer(args, self.text_predictor)

     @torch.no_grad()
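`TextEmbeddingModule` is now a plain `nn.Module`: without `ModelMixin` there is no automatic `device` property and no `register_to_config`, so `use_fp16` and `device` are stored explicitly. A hypothetical construction under the new signature:

```py
# Hypothetical construction; any TTF file works as font_path.
text_embedding_module = TextEmbeddingModule(
    font_path="arial-unicode-ms.ttf", use_fp16=False, device="cpu"
)
```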
@@ -487,7 +570,10 @@ class TextEmbeddingModule(ModelMixin, ConfigMixin):
         # preprocess pos_imgs(if numpy, make sure it's white pos in black bg)
         if draw_pos is None:
             pos_imgs = np.zeros((w, h, 1))
-        if isinstance(draw_pos, str):
+        if isinstance(draw_pos, PIL.Image.Image):
+            pos_imgs = np.array(draw_pos)[..., ::-1]
+            pos_imgs = 255 - pos_imgs
+        elif isinstance(draw_pos, str):
             draw_pos = cv2.imread(draw_pos)[..., ::-1]
             if draw_pos is None:
                 raise ValueError(f"Can't read draw_pos image from {draw_pos}!")
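The new branch accepts a `PIL.Image.Image` directly: `[..., ::-1]` converts RGB to BGR to match the `cv2.imread` path, and `255 - pos_imgs` inverts, which implies PIL inputs carry dark position regions on a light background. A hypothetical mask under that reading:

```py
import numpy as np
from PIL import Image, ImageDraw

# Hypothetical PIL draw_pos: dark boxes on a light background; the branch
# above turns this into the white-pos-on-black array used internally.
mask = Image.new("RGB", (512, 512), "white")
ImageDraw.Draw(mask).rectangle([96, 128, 416, 224], fill="black")
pos_imgs = 255 - np.array(mask)[..., ::-1]  # what the new branch computes
print(pos_imgs.min(), pos_imgs.max())  # 0 255
```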
@@ -580,7 +666,7 @@ class TextEmbeddingModule(ModelMixin, ConfigMixin):

         self.embedding_manager.encode_text(text_info)
         negative_prompt_embeds = self.frozen_CLIP_embedder_t3.encode(
-            [negative_prompt], embedding_manager=self.embedding_manager
+            [negative_prompt or ""], embedding_manager=self.embedding_manager
         )

         return prompt_embeds, negative_prompt_embeds, text_info, np_hint
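The `or ""` guard normalizes a `None` (or empty) negative prompt to an empty string before encoding; a quick check:

```py
# None and "" both normalize to [""]; real prompts pass through unchanged.
for negative_prompt in (None, "", "low quality, blurry"):
    print([negative_prompt or ""])
```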
@@ -799,7 +885,8 @@ class AuxiliaryLatentModule(nn.Module):
         # get masked_x
         masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint)
         masked_img = np.transpose(masked_img, (2, 0, 1))
-        masked_img = torch.from_numpy(masked_img.copy()).float().to(self.device)
+        device = next(self.vae.parameters()).device
+        masked_img = torch.from_numpy(masked_img.copy()).float().to(device)
         if self.use_fp16:
             masked_img = masked_img.half()
         masked_x = (retrieve_latents(self.vae.encode(masked_img[None, ...])) * self.vae.config.scaling_factor).detach()
@@ -842,9 +929,9 @@ class AuxiliaryLatentModule(nn.Module):
             new_string += char + " " * nSpace
         return new_string[:-nSpace]

-    def to(self, device):
-        self.device = device
-        self.vae = self.vae.to(device)
+    def to(self, *args, **kwargs):
+        self.vae = self.vae.to(*args, **kwargs)
+        self.device = self.vae.device
         return self

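Accepting `*args, **kwargs` lets `to` forward dtype-only moves such as `.to(torch.float16)` as well as device moves, then mirror the VAE's resulting device. A minimal illustration of the forwarding pattern (standalone names, not from the commit):

```py
import torch
import torch.nn as nn

class Wrapper:
    def __init__(self):
        self.inner = nn.Linear(2, 2)

    def to(self, *args, **kwargs):
        # Forward device and/or dtype to the wrapped module, then mirror
        # its device, as AuxiliaryLatentModule.to now does with the VAE.
        self.inner = self.inner.to(*args, **kwargs)
        self.device = next(self.inner.parameters()).device
        return self

w = Wrapper().to(torch.float16)  # a dtype-only move now works
print(w.inner.weight.dtype, w.device)  # torch.float16 cpu
```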
@@ -969,6 +1056,9 @@ class AnyTextPipeline(
         scheduler: KarrasDiffusionSchedulers,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPImageProcessor,
+        trust_remote_code: bool = False,
+        text_embedding_module: TextEmbeddingModule = None,
+        auxiliary_latent_module: AuxiliaryLatentModule = None,
         image_encoder: CLIPVisionModelWithProjection = None,
         requires_safety_checker: bool = True,
     ):
@@ -1877,6 +1967,7 @@ class AnyTextPipeline(
         text_encoder_lora_scale = (
             self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
         )
+        draw_pos = draw_pos.to(device=device) if isinstance(draw_pos, torch.Tensor) else draw_pos
         prompt_embeds, negative_prompt_embeds, text_info, np_hint = self.text_embedding_module(
             prompt,
             texts,
@@ -2035,7 +2126,7 @@ class AnyTextPipeline(
                     control_model_input,
                     t,
                     encoder_hidden_states=controlnet_prompt_embeds,
-                    guided_hint=guided_hint,
+                    controlnet_cond=guided_hint,
                     conditioning_scale=cond_scale,
                     guess_mode=guess_mode,
                     return_dict=False,
@@ -2116,3 +2207,9 @@ class AnyTextPipeline(
             return (image, has_nsfw_concept)

         return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+    def to(self, *args, **kwargs):
+        super().to(*args, **kwargs)
+        self.text_embedding_module.to(*args, **kwargs)
+        self.auxiliary_latent_module.to(*args, **kwargs)
+        return self
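`text_embedding_module` and `auxiliary_latent_module` are plain attributes rather than registered pipeline components, so the stock `DiffusionPipeline.to` would leave them behind; the override forwards the move explicitly. Usage stays a single call (shown commented out since it downloads weights):

```py
# import torch
# from diffusers import DiffusionPipeline
# pipe = DiffusionPipeline.from_pretrained("tolgacangoz/anytext",
#                                          trust_remote_code=True,
#                                          torch_dtype=torch.float16)
# pipe = pipe.to("cuda")  # now also moves both auxiliary modules
```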