Text-to-Image
Diffusers
Safetensors
tolgacangoz commited on
Commit
04fc5b3
·
verified ·
1 Parent(s): 8d57b7c

Upload anytext.py

Browse files
Files changed (1) hide show
  1. anytext.py +17 -5
anytext.py CHANGED
@@ -69,6 +69,7 @@ from diffusers.utils import (
69
  from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
70
  from diffusers.configuration_utils import register_to_config, ConfigMixin
71
  from diffusers.models.modeling_utils import ModelMixin
 
72
 
73
 
74
  checker = BasicTokenizer()
@@ -269,9 +270,20 @@ def crop_image(src_img, mask):
269
 
270
 
271
  def create_predictor(model_dir=None, model_lang="ch", device="cpu", use_fp16=False):
272
- model_file_path = model_dir
273
- if model_file_path is not None and not os.path.exists(model_file_path):
274
- raise ValueError("not find model file path {}".format(model_file_path))
 
 
 
 
 
 
 
 
 
 
 
275
 
276
  if model_lang == "ch":
277
  n_class = 6625
@@ -287,8 +299,8 @@ def create_predictor(model_dir=None, model_lang="ch", device="cpu", use_fp16=Fal
287
  )
288
 
289
  rec_model = RecModel(rec_config)
290
- if model_file_path is not None:
291
- rec_model.load_state_dict(torch.load(model_file_path, map_location=device))
292
  return rec_model
293
 
294
 
 
69
  from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
70
  from diffusers.configuration_utils import register_to_config, ConfigMixin
71
  from diffusers.models.modeling_utils import ModelMixin
72
+ from huggingface_hub import hf_hub_download
73
 
74
 
75
  checker = BasicTokenizer()
 
270
 
271
 
272
  def create_predictor(model_dir=None, model_lang="ch", device="cpu", use_fp16=False):
273
+ if model_dir is None or not os.path.exists(model_dir):
274
+ try:
275
+ # Use the repo id from which the pipeline was loaded
276
+ model_dir = hf_hub_download(
277
+ repo_id="tolgacangoz/anytext",
278
+ filename="text_embedding_module/OCR/ppv3_rec.pth",
279
+ local_dir=".cache/diffusers",
280
+ local_dir_use_symlinks=True
281
+ )
282
+ except Exception as e:
283
+ raise ValueError(f"Could not download the model file: {e}")
284
+
285
+ if model_dir is not None and not os.path.exists(model_dir):
286
+ raise ValueError("not find model file path {}".format(model_dir))
287
 
288
  if model_lang == "ch":
289
  n_class = 6625
 
299
  )
300
 
301
  rec_model = RecModel(rec_config)
302
+ state_dict = torch.load(model_dir, map_location=device)
303
+ rec_model.load_state_dict(state_dict)
304
  return rec_model
305
 
306