Upload anytext.py
Browse files- anytext.py +17 -5
anytext.py
CHANGED
@@ -69,6 +69,7 @@ from diffusers.utils import (
|
|
69 |
from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
|
70 |
from diffusers.configuration_utils import register_to_config, ConfigMixin
|
71 |
from diffusers.models.modeling_utils import ModelMixin
|
|
|
72 |
|
73 |
|
74 |
checker = BasicTokenizer()
|
@@ -269,9 +270,20 @@ def crop_image(src_img, mask):
|
|
269 |
|
270 |
|
271 |
def create_predictor(model_dir=None, model_lang="ch", device="cpu", use_fp16=False):
|
272 |
-
|
273 |
-
|
274 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
275 |
|
276 |
if model_lang == "ch":
|
277 |
n_class = 6625
|
@@ -287,8 +299,8 @@ def create_predictor(model_dir=None, model_lang="ch", device="cpu", use_fp16=Fal
|
|
287 |
)
|
288 |
|
289 |
rec_model = RecModel(rec_config)
|
290 |
-
|
291 |
-
|
292 |
return rec_model
|
293 |
|
294 |
|
|
|
69 |
from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
|
70 |
from diffusers.configuration_utils import register_to_config, ConfigMixin
|
71 |
from diffusers.models.modeling_utils import ModelMixin
|
72 |
+
from huggingface_hub import hf_hub_download
|
73 |
|
74 |
|
75 |
checker = BasicTokenizer()
|
|
|
270 |
|
271 |
|
272 |
def create_predictor(model_dir=None, model_lang="ch", device="cpu", use_fp16=False):
|
273 |
+
if model_dir is None or not os.path.exists(model_dir):
|
274 |
+
try:
|
275 |
+
# Use the repo id from which the pipeline was loaded
|
276 |
+
model_dir = hf_hub_download(
|
277 |
+
repo_id="tolgacangoz/anytext",
|
278 |
+
filename="text_embedding_module/OCR/ppv3_rec.pth",
|
279 |
+
local_dir=".cache/diffusers",
|
280 |
+
local_dir_use_symlinks=True
|
281 |
+
)
|
282 |
+
except Exception as e:
|
283 |
+
raise ValueError(f"Could not download the model file: {e}")
|
284 |
+
|
285 |
+
if model_dir is not None and not os.path.exists(model_dir):
|
286 |
+
raise ValueError("not find model file path {}".format(model_dir))
|
287 |
|
288 |
if model_lang == "ch":
|
289 |
n_class = 6625
|
|
|
299 |
)
|
300 |
|
301 |
rec_model = RecModel(rec_config)
|
302 |
+
state_dict = torch.load(model_dir, map_location=device)
|
303 |
+
rec_model.load_state_dict(state_dict)
|
304 |
return rec_model
|
305 |
|
306 |
|