Text-to-Image
Diffusers
Safetensors
tolgacangoz committed on
Commit
f4ebffc
·
verified ·
1 Parent(s): 6f8882f

Upload anytext.py

Browse files
Files changed (1) hide show
  1. anytext.py +12 -2
anytext.py CHANGED
@@ -29,6 +29,7 @@ from functools import partial
29
  from typing import Any, Callable, Dict, List, Optional, Tuple, Union
30
 
31
  import cv2
 
32
  import numpy as np
33
  import PIL.Image
34
  import torch
@@ -154,7 +155,12 @@ class EmbeddingManager(nn.Module):
154
  self.token_dim = token_dim
155
 
156
  self.proj = nn.Linear(40 * 64, token_dim)
157
- # self.proj.load_state_dict(load_file("proj.safetensors", device=str(embedder.device)))
 
 
 
 
 
158
  if use_fp16:
159
  self.proj = self.proj.to(dtype=torch.float16)
160
 
@@ -499,7 +505,10 @@ class TextEmbeddingModule(nn.Module):
499
  # preprocess pos_imgs(if numpy, make sure it's white pos in black bg)
500
  if draw_pos is None:
501
  pos_imgs = np.zeros((w, h, 1))
502
- if isinstance(draw_pos, str):
 
 
 
503
  draw_pos = cv2.imread(draw_pos)[..., ::-1]
504
  if draw_pos is None:
505
  raise ValueError(f"Can't read draw_pos image from {draw_pos}!")
@@ -981,6 +990,7 @@ class AnyTextPipeline(
981
  scheduler: KarrasDiffusionSchedulers,
982
  safety_checker: StableDiffusionSafetyChecker,
983
  feature_extractor: CLIPImageProcessor,
 
984
  text_embedding_module: TextEmbeddingModule = None,
985
  auxiliary_latent_module: AuxiliaryLatentModule = None,
986
  image_encoder: CLIPVisionModelWithProjection = None,
 
29
  from typing import Any, Callable, Dict, List, Optional, Tuple, Union
30
 
31
  import cv2
32
+ import huggingface_hub
33
  import numpy as np
34
  import PIL.Image
35
  import torch
 
155
  self.token_dim = token_dim
156
 
157
  self.proj = nn.Linear(40 * 64, token_dim)
158
+ proj_dir = hf_hub_download(
159
+ repo_id="tolgacangoz/anytext",
160
+ filename="text_embedding_module/proj.safetensors",
161
+ cache_dir=HF_MODULES_CACHE
162
+ )
163
+ self.proj.load_state_dict(load_file(proj_dir, device=str(embedder.device)))
164
  if use_fp16:
165
  self.proj = self.proj.to(dtype=torch.float16)
166
 
 
505
  # preprocess pos_imgs(if numpy, make sure it's white pos in black bg)
506
  if draw_pos is None:
507
  pos_imgs = np.zeros((w, h, 1))
508
+ if isinstance(draw_pos, PIL.Image.Image):
509
+ pos_imgs = np.array(draw_pos)[..., ::-1]
510
+ pos_imgs = 255 - pos_imgs
511
+ elif isinstance(draw_pos, str):
512
  draw_pos = cv2.imread(draw_pos)[..., ::-1]
513
  if draw_pos is None:
514
  raise ValueError(f"Can't read draw_pos image from {draw_pos}!")
 
990
  scheduler: KarrasDiffusionSchedulers,
991
  safety_checker: StableDiffusionSafetyChecker,
992
  feature_extractor: CLIPImageProcessor,
993
+ trust_remote_code: bool = False,
994
  text_embedding_module: TextEmbeddingModule = None,
995
  auxiliary_latent_module: AuxiliaryLatentModule = None,
996
  image_encoder: CLIPVisionModelWithProjection = None,