Upload anytext.py
Browse files- anytext.py +12 -2
anytext.py
CHANGED
@@ -29,6 +29,7 @@ from functools import partial
|
|
29 |
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
30 |
|
31 |
import cv2
|
|
|
32 |
import numpy as np
|
33 |
import PIL.Image
|
34 |
import torch
|
@@ -154,7 +155,12 @@ class EmbeddingManager(nn.Module):
|
|
154 |
self.token_dim = token_dim
|
155 |
|
156 |
self.proj = nn.Linear(40 * 64, token_dim)
|
157 |
-
|
|
|
|
|
|
|
|
|
|
|
158 |
if use_fp16:
|
159 |
self.proj = self.proj.to(dtype=torch.float16)
|
160 |
|
@@ -499,7 +505,10 @@ class TextEmbeddingModule(nn.Module):
|
|
499 |
# preprocess pos_imgs(if numpy, make sure it's white pos in black bg)
|
500 |
if draw_pos is None:
|
501 |
pos_imgs = np.zeros((w, h, 1))
|
502 |
-
if isinstance(draw_pos,
|
|
|
|
|
|
|
503 |
draw_pos = cv2.imread(draw_pos)[..., ::-1]
|
504 |
if draw_pos is None:
|
505 |
raise ValueError(f"Can't read draw_pos image from {draw_pos}!")
|
@@ -981,6 +990,7 @@ class AnyTextPipeline(
|
|
981 |
scheduler: KarrasDiffusionSchedulers,
|
982 |
safety_checker: StableDiffusionSafetyChecker,
|
983 |
feature_extractor: CLIPImageProcessor,
|
|
|
984 |
text_embedding_module: TextEmbeddingModule = None,
|
985 |
auxiliary_latent_module: AuxiliaryLatentModule = None,
|
986 |
image_encoder: CLIPVisionModelWithProjection = None,
|
|
|
29 |
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
30 |
|
31 |
import cv2
|
32 |
+
import huggingface_hub
|
33 |
import numpy as np
|
34 |
import PIL.Image
|
35 |
import torch
|
|
|
155 |
self.token_dim = token_dim
|
156 |
|
157 |
self.proj = nn.Linear(40 * 64, token_dim)
|
158 |
+
proj_dir = hf_hub_download(
|
159 |
+
repo_id="tolgacangoz/anytext",
|
160 |
+
filename="text_embedding_module/proj.safetensors",
|
161 |
+
cache_dir=HF_MODULES_CACHE
|
162 |
+
)
|
163 |
+
self.proj.load_state_dict(load_file(proj_dir, device=str(embedder.device)))
|
164 |
if use_fp16:
|
165 |
self.proj = self.proj.to(dtype=torch.float16)
|
166 |
|
|
|
505 |
# preprocess pos_imgs(if numpy, make sure it's white pos in black bg)
|
506 |
if draw_pos is None:
|
507 |
pos_imgs = np.zeros((w, h, 1))
|
508 |
+
if isinstance(draw_pos, PIL.Image.Image):
|
509 |
+
pos_imgs = np.array(draw_pos)[..., ::-1]
|
510 |
+
pos_imgs = 255 - pos_imgs
|
511 |
+
elif isinstance(draw_pos, str):
|
512 |
draw_pos = cv2.imread(draw_pos)[..., ::-1]
|
513 |
if draw_pos is None:
|
514 |
raise ValueError(f"Can't read draw_pos image from {draw_pos}!")
|
|
|
990 |
scheduler: KarrasDiffusionSchedulers,
|
991 |
safety_checker: StableDiffusionSafetyChecker,
|
992 |
feature_extractor: CLIPImageProcessor,
|
993 |
+
trust_remote_code: bool = False,
|
994 |
text_embedding_module: TextEmbeddingModule = None,
|
995 |
auxiliary_latent_module: AuxiliaryLatentModule = None,
|
996 |
image_encoder: CLIPVisionModelWithProjection = None,
|