Upload anytext.py
anytext.py  CHANGED  (+47 -77)
@@ -41,8 +41,10 @@ from safetensors.torch import load_file
 from skimage.transform._geometric import _umeyama as get_sym_mat
 from torch import nn
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
 
 from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
+from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
 from diffusers.loaders import (
     FromSingleFileMixin,
@@ -52,13 +54,12 @@ from diffusers.loaders import (
 )
 from diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
 from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.models.modeling_utils import ModelMixin
 from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from diffusers.schedulers import KarrasDiffusionSchedulers
-from diffusers.configuration_utils import register_to_config, ConfigMixin
-from diffusers.models.modeling_utils import ModelMixin
 from diffusers.utils import (
     USE_PEFT_BACKEND,
     deprecate,
@@ -154,21 +155,14 @@ EXAMPLE_DOC_STRING = """
 >>> # I chose a font file shared by an HF staff:
 >>> !wget https://huggingface.co/spaces/ysharma/TranslateQuotesInImageForwards/resolve/main/arial-unicode-ms.ttf
 
->>> # load control net and stable diffusion v1-5
 >>> anytext_controlnet = AnyTextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16,
 ...                                                             variant="fp16",)
 >>> pipe = DiffusionPipeline.from_pretrained("tolgacangoz/anytext", font_path="arial-unicode-ms.ttf",
 ...                                          controlnet=anytext_controlnet, torch_dtype=torch.float16,
-...                                          trust_remote_code=
+...                                          trust_remote_code=False,  # One needs to give permission to run this pipeline's code
 ...                                          ).to("cuda")
 
 >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
->>> # uncomment following line if PyTorch>=2.0 is not installed for memory optimization
->>> #pipe.enable_xformers_memory_efficient_attention()
-
->>> # uncomment following line if you want to offload the model to CPU for memory optimization
->>> # also remove the `.to("cuda")` part
->>> #pipe.enable_model_cpu_offload()
 
 >>> # generate image
 >>> prompt = 'photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream'
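
A minimal sketch of how the docstring example might be continued to actually produce an image, assuming the pipeline follows the standard Stable Diffusion call convention (a StableDiffusionPipelineOutput with an .images list); any AnyText-specific inputs such as glyph/position images are omitted here and would need to follow the pipeline's actual signature:

>>> generator = torch.Generator("cuda").manual_seed(0)
>>> # Plain text-to-image call; AnyText-specific arguments are assumptions left out of this sketch.
>>> image = pipe(prompt, num_inference_steps=20, generator=generator).images[0]
>>> image.save("anytext_coffee.png")
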
@@ -211,8 +205,8 @@ class EmbeddingManager(ModelMixin, ConfigMixin):
         embedder,
         placeholder_string="*",
         use_fp16=False,
-        token_dim
-        get_recog_emb
+        token_dim=768,
+        get_recog_emb=None,
     ):
         super().__init__()
         get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer)
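
The new token_dim=768 default presumably matches the hidden width of the CLIP ViT-L/14 text encoder used by Stable Diffusion v1.x, so injected glyph embeddings line up with the prompt embeddings. A quick, self-contained check of that width (the repo name here is the public CLIP checkpoint, not part of this diff):

from transformers import CLIPTextConfig

# Hidden size of the SD v1.x text encoder (CLIP ViT-L/14) is 768.
cfg = CLIPTextConfig.from_pretrained("openai/clip-vit-large-patch14")
print(cfg.hidden_size)  # 768
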
@@ -227,9 +221,7 @@ class EmbeddingManager(ModelMixin, ConfigMixin):
         if use_fp16:
             self.proj = self.proj.to(dtype=torch.float16)
 
-        # self.register_parameter("proj", proj)
         self.placeholder_token = get_token_for_string(placeholder_string)
-        # self.register_config(placeholder_token=placeholder_token)
 
     @torch.no_grad()
     def encode_text(self, text_info):
@@ -350,12 +342,19 @@ def create_predictor(model_lang="ch", device="cpu", use_fp16=False):
         n_class = 97
     else:
         raise ValueError(f"Unsupported OCR recog model_lang: {model_lang}")
-    rec_config =
-    in_channels
-    backbone
-    neck
-
-
+    rec_config = {
+        "in_channels": 3,
+        "backbone": {"type": "MobileNetV1Enhance", "scale": 0.5, "last_conv_stride": [1, 2], "last_pool_type": "avg"},
+        "neck": {
+            "type": "SequenceEncoder",
+            "encoder_type": "svtr",
+            "dims": 64,
+            "depth": 2,
+            "hidden_dims": 120,
+            "use_guide": True,
+        },
+        "head": {"type": "CTCHead", "fc_decay": 0.00001, "out_channels": n_class, "return_feats": True},
+    }
 
     rec_model = RecModel(rec_config)
     state_dict = torch.load(model_dir, map_location=device)
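
The recognizer configuration is now a plain nested dict of string keys. If any downstream code prefers attribute-style access (cfg.backbone.scale rather than cfg["backbone"]["scale"]), a thin wrapper bridges the two; this is an illustrative sketch only and says nothing about how RecModel actually consumes the dict:

from types import SimpleNamespace

def to_namespace(obj):
    # Recursively turn nested dicts into attribute-accessible namespaces.
    if isinstance(obj, dict):
        return SimpleNamespace(**{k: to_namespace(v) for k, v in obj.items()})
    if isinstance(obj, list):
        return [to_namespace(v) for v in obj]
    return obj

# Trimmed copy of the config shape from the hunk above, just for the demo.
rec_config = {
    "in_channels": 3,
    "backbone": {"type": "MobileNetV1Enhance", "scale": 0.5},
    "neck": {"type": "SequenceEncoder", "encoder_type": "svtr"},
}
cfg = to_namespace(rec_config)
print(cfg.backbone.scale)      # 0.5
print(cfg.neck.encoder_type)   # svtr
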
@@ -521,12 +520,6 @@ class TextRecognizer(object):
         return loss
 
 
-import torch
-from torch import nn
-from transformers import CLIPTextModel, CLIPTokenizer
-from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
-
-
 class AbstractEncoder(nn.Module):
     def __init__(self):
         super().__init__()
@@ -537,6 +530,7 @@ class AbstractEncoder(nn.Module):
 
 class FrozenCLIPEmbedderT3(AbstractEncoder, ModelMixin, ConfigMixin):
     """Uses the CLIP transformer encoder for text (from Hugging Face)"""
+
     @register_to_config
     def __init__(
         self,
@@ -548,11 +542,13 @@ class FrozenCLIPEmbedderT3(AbstractEncoder, ModelMixin, ConfigMixin):
     ):
         super().__init__()
         self.tokenizer = CLIPTokenizer.from_pretrained("tolgacangoz/anytext", subfolder="tokenizer")
-        self.transformer = CLIPTextModel.from_pretrained(
-
-
-
-
+        self.transformer = CLIPTextModel.from_pretrained(
+            "tolgacangoz/anytext",
+            subfolder="text_encoder",
+            torch_dtype=torch.float16 if use_fp16 else torch.float32,
+            variant="fp16" if use_fp16 else None,
+        )
+
         if freeze:
             self.freeze()
 
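
FrozenCLIPEmbedderT3 now loads its CLIPTextModel directly from the hub and freezes it when freeze=True. The body of freeze() is not part of this hunk; the sketch below is only a typical implementation (an assumption), switching the transformer to eval mode and turning off gradients:

def freeze(self):
    # Assumed implementation: keep the text encoder fixed during any training/fine-tuning.
    self.transformer = self.transformer.eval()
    for param in self.parameters():
        param.requires_grad = False
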
@@ -731,11 +727,6 @@ class FrozenCLIPEmbedderT3(AbstractEncoder, ModelMixin, ConfigMixin):
         tokens_list.append(remaining_group_pad)
         return tokens_list
 
-    # def to(self, *args, **kwargs):
-    #     self.transformer = self.transformer.to(*args, **kwargs)
-    #     self.device = self.transformer.device
-    #     return self
-
 
 class TextEmbeddingModule(ModelMixin, ConfigMixin):
     @register_to_config
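
The commented-out to() override is deleted; because self.transformer is a registered nn.Module child, calling .to(...) on the parent already moves it along. A small self-contained illustration of that standard PyTorch behavior:

import torch
from torch import nn

class Parent(nn.Module):
    def __init__(self):
        super().__init__()
        self.child = nn.Linear(4, 4)  # registered submodule

parent = Parent()
parent.to(torch.float16)           # dtype propagates to the child
print(parent.child.weight.dtype)   # torch.float16
if torch.cuda.is_available():
    parent.to("cuda")              # device propagates as well
    print(parent.child.weight.device)
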
@@ -743,25 +734,21 @@ class TextEmbeddingModule(ModelMixin, ConfigMixin):
         super().__init__()
         font = ImageFont.truetype(font_path, 60)
 
-        # self.use_fp16 = use_fp16
-        # self.device = device
         self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=device, use_fp16=use_fp16)
         self.embedding_manager = EmbeddingManager(self.frozen_CLIP_embedder_t3, use_fp16=use_fp16)
         self.text_predictor = create_predictor(device=device, use_fp16=use_fp16).eval()
-        args = {
-
-
-
-
-
-
-
+        args = {
+            "rec_image_shape": "3, 48, 320",
+            "rec_batch_num": 6,
+            "rec_char_dict_path": hf_hub_download(
+                repo_id="tolgacangoz/anytext",
+                filename="text_embedding_module/OCR/ppocr_keys_v1.txt",
+                cache_dir=HF_MODULES_CACHE,
+            ),
+            "use_fp16": use_fp16,
+        }
         self.embedding_manager.recog = TextRecognizer(args, self.text_predictor)
 
-        # self.register_modules(
-        #     frozen_CLIP_embedder_t3=frozen_CLIP_embedder_t3,
-        #     embedding_manager=embedding_manager,
-        # )
         self.register_to_config(font=font)
 
     @torch.no_grad()
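
The recognizer arguments now resolve the OCR character dictionary with hf_hub_download, so the file is pulled from the Hub (and cached) when the module is constructed. The same call in isolation; repo_id and filename are taken from the hunk above, and omitting cache_dir simply falls back to huggingface_hub's default cache:

from huggingface_hub import hf_hub_download

# Downloads on first use, then returns the cached local path on subsequent calls.
char_dict_path = hf_hub_download(
    repo_id="tolgacangoz/anytext",
    filename="text_embedding_module/OCR/ppocr_keys_v1.txt",
)
print(char_dict_path)  # local path under the Hugging Face cache directory
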
@@ -873,8 +860,6 @@ class TextEmbeddingModule(ModelMixin, ConfigMixin):
             text_info["gly_line"] += [self.arr2tensor(gly_line, num_images_per_prompt)]
             text_info["positions"] += [self.arr2tensor(pos, num_images_per_prompt)]
 
-        # hint = self.arr2tensor(np_hint, len(prompt))
-
         self.embedding_manager.encode_text(text_info)
         prompt_embeds = self.frozen_CLIP_embedder_t3.encode([prompt], embedding_manager=self.embedding_manager)
 
@@ -1028,11 +1013,6 @@ class TextEmbeddingModule(ModelMixin, ConfigMixin):
                 new_string += char + " " * nSpace
         return new_string[:-nSpace]
 
-    # def to(self, *args, **kwargs):
-    #     self.frozen_CLIP_embedder_t3 = self.frozen_CLIP_embedder_t3.to(*args, **kwargs)
-    #     self.embedding_manager = self.embedding_manager.to(*args, **kwargs)
-    #     return self
-
 
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
 def retrieve_latents(
@@ -1052,13 +1032,10 @@ class AuxiliaryLatentModule(ModelMixin, ConfigMixin):
     @register_to_config
     def __init__(
         self,
-        # font_path,
         vae,
         device="cpu",
     ):
         super().__init__()
-        # self.font = ImageFont.truetype(font_path, 60)
-        # self.vae = vae.eval() if vae is not None else None
 
     @torch.no_grad()
     def forward(
@@ -1100,7 +1077,9 @@ class AuxiliaryLatentModule(ModelMixin, ConfigMixin):
         masked_img = torch.from_numpy(masked_img.copy()).float().to(device)
         if dtype == torch.float16:
             masked_img = masked_img.half()
-        masked_x = (
+        masked_x = (
+            retrieve_latents(self.config.vae.encode(masked_img[None, ...])) * self.config.vae.config.scaling_factor
+        ).detach()
         if dtype == torch.float16:
             masked_x = masked_x.half()
         text_info["masked_x"] = torch.cat([masked_x for _ in range(num_images_per_prompt)], dim=0)
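
masked_x is now computed by encoding the masked image with the VAE and scaling by the VAE's scaling_factor, via the retrieve_latents helper copied from the img2img pipeline. The same pattern in isolation, using a standalone AutoencoderKL checkpoint as a stand-in (sampling from the posterior here; retrieve_latents can also return its mode):

import torch
from diffusers import AutoencoderKL

# Stand-in VAE; the pipeline uses the vae handed to AuxiliaryLatentModule instead.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")

image = torch.randn(1, 3, 512, 512)  # RGB batch scaled to [-1, 1]
with torch.no_grad():
    posterior = vae.encode(image).latent_dist
    latents = (posterior.sample() * vae.config.scaling_factor).detach()
print(latents.shape)  # torch.Size([1, 4, 64, 64])
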
@@ -1140,11 +1119,6 @@ class AuxiliaryLatentModule(ModelMixin, ConfigMixin):
                 new_string += char + " " * nSpace
         return new_string[:-nSpace]
 
-    # def to(self, *args, **kwargs):
-    #     self.vae = self.vae.to(*args, **kwargs)
-    #     self.device = self.vae.device
-    #     return self
-
 
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
@@ -1266,7 +1240,7 @@ class AnyTextPipeline(
         scheduler: KarrasDiffusionSchedulers,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPImageProcessor,
-        font_path: str =
+        font_path: str = None,
         text_embedding_module: Optional[TextEmbeddingModule] = None,
         auxiliary_latent_module: Optional[AuxiliaryLatentModule] = None,
         trust_remote_code: bool = False,
@@ -1274,15 +1248,11 @@ class AnyTextPipeline(
         requires_safety_checker: bool = True,
     ):
         super().__init__()
-
-        font_path
-
-        )
-        auxiliary_latent_module = AuxiliaryLatentModule(
-            # font_path=font_path,
-            vae=vae,
-            # use_fp16=unet.dtype == torch.float16,
-        )
+        if font_path is None:
+            raise ValueError("font_path is required!")
+
+        text_embedding_module = TextEmbeddingModule(font_path=font_path, use_fp16=unet.dtype == torch.float16)
+        auxiliary_latent_module = AuxiliaryLatentModule(vae=vae)
 
         if safety_checker is None and requires_safety_checker:
             logger.warning(
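
With this hunk the pipeline constructs TextEmbeddingModule and AuxiliaryLatentModule itself, and font_path becomes effectively mandatory: leaving it at None now fails fast in __init__. A sketch of the resulting loading contract, reusing the repo and font names from the docstring example; keywords beyond font_path are assumptions carried over from that example:

import torch
from diffusers import DiffusionPipeline

# Omitting font_path raises ValueError("font_path is required!") inside __init__.
pipe = DiffusionPipeline.from_pretrained(
    "tolgacangoz/anytext",
    font_path="arial-unicode-ms.ttf",  # forwarded to AnyTextPipeline.__init__
    torch_dtype=torch.float16,
    trust_remote_code=True,  # remote pipeline code has to be allowed to execute
)
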
@@ -1321,7 +1291,7 @@ class AnyTextPipeline(
         self.control_image_processor = VaeImageProcessor(
             vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
         )
-        self.register_to_config(requires_safety_checker=requires_safety_checker)
+        self.register_to_config(requires_safety_checker=requires_safety_checker)
 
     def modify_prompt(self, prompt):
         prompt = prompt.replace("“", '"')
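
requires_safety_checker is recorded through ConfigMixin's register_to_config, so it lands in the object's config (and is serialized on save_pretrained) rather than living only as a plain attribute. A minimal toy sketch of that mechanism, assuming standard diffusers ConfigMixin behavior; the class and config file name here are made up for illustration:

from diffusers.configuration_utils import ConfigMixin, register_to_config

class ToyComponent(ConfigMixin):
    config_name = "toy_config.json"  # hypothetical config file name

    @register_to_config
    def __init__(self, requires_safety_checker: bool = True):
        pass

toy = ToyComponent(requires_safety_checker=False)
print(toy.config.requires_safety_checker)  # False, read back from the registered config
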