Text-to-Image
Diffusers
Safetensors
tolgacangoz committed (verified)
Commit a113455 · 1 Parent(s): b4b20cb

Upload anytext.py

Files changed (1)
  1. anytext.py +12 -62
anytext.py CHANGED
@@ -35,6 +35,7 @@ import PIL.Image
 import torch
 import torch.nn.functional as F
 from easydict import EasyDict as edict
+from frozen_clip_embedder_t3 import FrozenCLIPEmbedderT3
 from huggingface_hub import hf_hub_download
 from ocr_recog.RecModel import RecModel
 from PIL import Image, ImageDraw, ImageFont
@@ -206,13 +207,12 @@ def get_recog_emb(encoder, img_list):
 class EmbeddingManager(nn.Module):
     def __init__(
         self,
-        clip_tokenizer,
+        embedder,
         placeholder_string="*",
         use_fp16=False,
-        device="cpu",
     ):
         super().__init__()
-        get_token_for_string = partial(get_clip_token_for_string, clip_tokenizer)
+        get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer)
         token_dim = 768
         self.get_recog_emb = None
         self.token_dim = token_dim
@@ -223,7 +223,7 @@ class EmbeddingManager(nn.Module):
             filename="text_embedding_module/proj.safetensors",
             cache_dir=HF_MODULES_CACHE,
         )
-        self.proj.load_state_dict(load_file(proj_dir, device=str(device)))
+        self.proj.load_state_dict(load_file(proj_dir, device=str(embedder.device)))
         if use_fp16:
             self.proj = self.proj.to(dtype=torch.float16)
 
@@ -526,20 +526,14 @@ class TextEmbeddingModule(nn.Module):
         self.font = ImageFont.truetype(font_path, 60)
         self.use_fp16 = use_fp16
         self.device = device
-
-        # Replace instantiation of frozen_CLIP_embedder_t3
-        version = "openai/clip-vit-large-patch14"
-        torch_dtype = torch.float16 if use_fp16 else torch.float32
-        self.clip_tokenizer = CLIPTokenizer.from_pretrained(version)
-        self.clip_text_model = CLIPTextModel.from_pretrained(version, torch_dtype=torch_dtype).to(device)
-        self.max_length = 77  # same as before
-
-        self.embedding_manager = EmbeddingManager(self.clip_tokenizer, use_fp16=use_fp16, device=device)
+        self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=device, use_fp16=use_fp16)
+        self.embedding_manager = EmbeddingManager(self.frozen_CLIP_embedder_t3, use_fp16=use_fp16)
         rec_model_dir = "./text_embedding_module/OCR/ppv3_rec.pth"
         self.text_predictor = create_predictor(rec_model_dir, device=device, use_fp16=use_fp16).eval()
         args = {}
         args["rec_image_shape"] = "3, 48, 320"
         args["rec_batch_num"] = 6
+        args["rec_char_dict_path"] = "./text_embedding_module/OCR/ppocr_keys_v1.txt"
         args["rec_char_dict_path"] = hf_hub_download(
             repo_id="tolgacangoz/anytext",
             filename="text_embedding_module/OCR/ppocr_keys_v1.txt",
@@ -548,50 +542,6 @@ class TextEmbeddingModule(nn.Module):
         args["use_fp16"] = use_fp16
         self.embedding_manager.recog = TextRecognizer(args, self.text_predictor)
 
-    # New helper method to mimic old encode() functionality with chunk splitting
-    def _encode_text(self, texts, embedding_manager=None, **kwargs):
-        batch_encoding = self.clip_tokenizer(
-            texts,
-            truncation=False,
-            max_length=self.max_length,
-            padding="longest",
-            return_tensors="pt",
-        )
-        input_ids = batch_encoding["input_ids"]
-        tokens_list = self._split_chunks(input_ids)
-        embeds_list = []
-        for tokens in tokens_list:
-            tokens = tokens.to(self.device)
-            outputs = self.clip_text_model(input_ids=tokens, **kwargs)
-            # use last_hidden_state as in the old version
-            embeds_list.append(outputs.last_hidden_state)
-        return torch.cat(embeds_list, dim=1)
-
-    # New helper for splitting tokens (mimicking split_chunks behavior)
-    def _split_chunks(self, input_ids, chunk_size=75):
-        tokens_list = []
-        bs, n = input_ids.shape
-        id_start = input_ids[:, 0].unsqueeze(1)
-        id_end = input_ids[:, -1].unsqueeze(1)
-        if n == 2:  # empty caption
-            tokens_list.append(torch.cat((id_start,) + (id_end,) * (chunk_size + 1), dim=1))
-            return tokens_list
-
-        trimmed = input_ids[:, 1:-1]
-        num_full = (n - 2) // chunk_size
-        for i in range(num_full):
-            group = trimmed[:, i * chunk_size : (i + 1) * chunk_size]
-            group_pad = torch.cat((id_start, group, id_end), dim=1)
-            tokens_list.append(group_pad)
-        rem = (n - 2) % chunk_size
-        if rem > 0:
-            group = trimmed[:, -rem:]
-            pad_cols = chunk_size - group.shape[1]
-            padding = id_end.expand(bs, pad_cols)
-            group_pad = torch.cat((id_start, group, padding, id_end), dim=1)
-            tokens_list.append(group_pad)
-        return tokens_list
-
     @torch.no_grad()
     def forward(
         self,
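
The helpers removed here implemented CLIP-style chunking: each window carries up to 75 content tokens re-wrapped with BOS and EOS, and the tail window is padded with the EOS/pad id. For reference, a minimal standalone sketch of that behavior, using a hypothetical split_chunks and assuming input_ids whose first column is BOS and whose last column is EOS/padding:

import torch

def split_chunks(input_ids: torch.Tensor, chunk_size: int = 75) -> list[torch.Tensor]:
    # Re-wrap each window of up to chunk_size interior tokens with BOS/EOS,
    # padding the final window with the EOS/pad id, as the removed helper did.
    bs, n = input_ids.shape
    id_start, id_end = input_ids[:, :1], input_ids[:, -1:]
    if n == 2:  # empty caption: BOS + EOS only
        return [torch.cat((id_start,) + (id_end,) * (chunk_size + 1), dim=1)]
    trimmed = input_ids[:, 1:-1]
    chunks = []
    for i in range(0, trimmed.shape[1], chunk_size):
        group = trimmed[:, i : i + chunk_size]
        if group.shape[1] < chunk_size:  # pad the tail window with EOS
            group = torch.cat((group, id_end.expand(bs, chunk_size - group.shape[1])), dim=1)
        chunks.append(torch.cat((id_start, group, id_end), dim=1))
    return chunks

# A 100-token interior splits into two 77-wide windows (75 tokens + BOS + EOS).
ids = torch.randint(0, 49405, (1, 102))
assert [c.shape[1] for c in split_chunks(ids)] == [77, 77]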
@@ -704,9 +654,10 @@ class TextEmbeddingModule(nn.Module):
         # hint = self.arr2tensor(np_hint, len(prompt))
 
         self.embedding_manager.encode_text(text_info)
-        prompt_embeds = self._encode_text([prompt], embedding_manager=self.embedding_manager)
+        prompt_embeds = self.frozen_CLIP_embedder_t3.encode([prompt], embedding_manager=self.embedding_manager)
+
         self.embedding_manager.encode_text(text_info)
-        negative_prompt_embeds = self._encode_text(
+        negative_prompt_embeds = self.frozen_CLIP_embedder_t3.encode(
             [negative_prompt or ""], embedding_manager=self.embedding_manager
         )
 
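
Both the removed _encode_text and, per its own comment, the encode() it mimicked concatenate per-chunk hidden states along the sequence dimension, so a prompt spanning k windows yields embeddings of shape (batch, 77 * k, 768). A minimal shape check under that assumption, 768 being the token_dim set in EmbeddingManager above:

import torch

# Hypothetical check: a prompt spanning two 77-token windows produces
# embeddings concatenated along dim=1, i.e. shape (1, 154, 768).
chunk_embeds = [torch.randn(1, 77, 768) for _ in range(2)]
prompt_embeds = torch.cat(chunk_embeds, dim=1)
assert prompt_embeds.shape == (1, 154, 768)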
@@ -856,11 +807,10 @@ class TextEmbeddingModule(nn.Module):
         return new_string[:-nSpace]
 
     def to(self, *args, **kwargs):
-        self.clip_text_model = self.clip_text_model.to(*args, **kwargs)
-        self.device = self.clip_text_model.device
+        self.frozen_CLIP_embedder_t3 = self.frozen_CLIP_embedder_t3.to(*args, **kwargs)
         self.embedding_manager = self.embedding_manager.to(*args, **kwargs)
         self.text_predictor = self.text_predictor.to(*args, **kwargs)
-        self.device = self.clip_text_model.device
+        self.device = self.frozen_CLIP_embedder_t3.device
         return self
 
 
 