Spaces:
Sleeping
Sleeping
File size: 4,565 Bytes
99f3ba3 92c11dc 6b1f545 99f3ba3 6b1f545 99f3ba3 6b1f545 99f3ba3 6b1f545 99f3ba3 6b1f545 99f3ba3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
from uuid import uuid4
import gradio as gr
from laia.scripts.htr.decode_ctc import run as decode
from laia.common.arguments import CommonArgs, DataArgs, TrainerArgs, DecodeArgs
import sys
from tempfile import NamedTemporaryFile, mkdtemp
from pathlib import Path
from contextlib import redirect_stdout
import re
from huggingface_hub import snapshot_download
# Scratch directory holding the resized input images handed to PyLaia.
# NOTE(review): created once at import time; this module never removes it.
images = Path(mkdtemp())
# Image ids are str(uuid4()) values: 36 characters of lowercase hex and hyphens.
IMAGE_ID_PATTERN = r"(?P<image_id>[-a-z0-9]{36})"
# Line-level confidence score (decoding runs with print_line_confidence_scores).
CONFIDENCE_PATTERN = r"(?P<confidence>[0-9.]+)"
TEXT_PATTERN = r"\s*(?P<text>.*)\s*"
# One decoded line: "<image_id> <confidence> <text>". The score is delimited by
# a look-ahead (whitespace or end of line) rather than a literal space, so a
# line whose transcription is empty - which ends right after the score once
# stripped - still matches, with an empty "text" group.
LINE_PREDICTION = re.compile(
    rf"{IMAGE_ID_PATTERN} {CONFIDENCE_PATTERN}(?=\s|$){TEXT_PATTERN}"
)
# Hugging Face Hub repositories offered in the model dropdown; the first entry
# is the dropdown's default selection. Each repo id appears only once.
models_name = [
    "Teklia/pylaia-rimes",
    "Teklia/pylaia-belfort",
    "Teklia/pylaia-casia-hwdb2",
    "Teklia/pylaia-esposalles",
    "Teklia/pylaia-himanis",
    "Teklia/pylaia-home-alcar",
    "Teklia/pylaia-iam",
    "Teklia/pylaia-newseye-austrian",
    "Teklia/pylaia-norhand-v1",
    "Teklia/pylaia-norhand-v2",
    "Teklia/pylaia-norhand-v3",
    "Teklia/pylaia-popp",
    "Teklia/pylaia-PELLET-CasimirMarius",
]
# Memo of downloaded model snapshots: Hub repo id -> local directory (Path).
MODELS = {}
# Target height, in pixels, of the line images passed to the model.
DEFAULT_HEIGHT = 128


def get_width(image, height=DEFAULT_HEIGHT):
    """Return the width that preserves *image*'s aspect ratio at *height*.

    *image* only needs ``width`` and ``height`` attributes. The result is a
    float; callers truncate/round it themselves.
    """
    return height * (image.width / image.height)
def load_model(model_name):
    """Return the local snapshot directory of the Hub repo *model_name*.

    The first call for a given repo downloads it with ``snapshot_download``;
    the resulting ``Path`` is memoized in ``MODELS`` and reused afterwards.
    """
    local_dir = MODELS.get(model_name)
    if local_dir is None:
        local_dir = MODELS[model_name] = Path(snapshot_download(model_name))
    return local_dir
def predict(model_name, input_img):
    """Transcribe one text-line image with the selected PyLaia model.

    The image is resized to ``DEFAULT_HEIGHT`` pixels high (keeping its aspect
    ratio), saved in the ``images`` scratch directory under a random UUID, and
    decoded with PyLaia's ``decode_ctc`` entry point. PyLaia writes its
    predictions to stdout, which is captured to a temporary file and parsed
    with ``LINE_PREDICTION``.

    Args:
        model_name: Hugging Face Hub repository id of the model to use.
        input_img: PIL image of a single text line (from the
            ``gr.Image(type="pil")`` input), or ``None`` when the user
            submitted without an image.

    Returns:
        ``(resized_image, {"text": text, "score": score})`` where ``score`` is
        the line confidence exactly as printed by PyLaia (a string).

    Raises:
        gr.Error: if no image was provided.
        RuntimeError: if PyLaia did not print exactly one parsable line.
    """
    if input_img is None:
        # Show a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please upload an image of a text line.")

    model_dir = load_model(model_name)
    temperature = 2.0
    batch_size = 1
    weights_path = model_dir / "weights.ckpt"
    syms_path = model_dir / "syms.txt"

    language_model_params = {"language_model_weight": 1.0}
    # The language model is only used when the repo ships a tokens file.
    use_language_model = (model_dir / "tokens.txt").exists()
    if use_language_model:
        language_model_params.update(
            {
                "language_model_path": str(model_dir / "language_model.arpa.gz"),
                "lexicon_path": str(model_dir / "lexicon.txt"),
                "tokens_path": str(model_dir / "tokens.txt"),
            }
        )

    common_args = CommonArgs(
        checkpoint=str(weights_path.relative_to(model_dir)),
        train_path=str(model_dir),
        experiment_dirname="",
    )
    data_args = DataArgs(batch_size=batch_size, color_mode="L")
    trainer_args = TrainerArgs(
        # Disable progress bar else it messes with frontend display
        progress_bar_refresh_rate=0
    )
    decode_args = DecodeArgs(
        include_img_ids=True,
        join_string="",
        convert_spaces=True,
        print_line_confidence_scores=True,
        print_word_confidence_scores=False,
        temperature=temperature,
        use_language_model=use_language_model,
        **language_model_params,
    )

    image_id = uuid4()
    image_path = images / f"{image_id}.jpg"
    # Resize to DEFAULT_HEIGHT pixels high, keeping the aspect ratio; clamp the
    # width to at least 1 px so very tall, narrow inputs stay valid for PIL.
    width = max(1, int(get_width(input_img)))
    input_img = input_img.resize((width, DEFAULT_HEIGHT))
    try:
        input_img.save(str(image_path))
        with NamedTemporaryFile() as pred_stdout, NamedTemporaryFile() as img_list:
            # Export image list
            Path(img_list.name).write_text(str(image_id), encoding="utf-8")
            # Capture stdout as that's where PyLaia outputs predictions. The
            # capture file is closed (hence flushed) before it is read back,
            # and UTF-8 is used both ways so non-ASCII transcriptions survive
            # regardless of the process locale.
            with open(pred_stdout.name, mode="w", encoding="utf-8") as out, redirect_stdout(out):
                decode(
                    syms=str(syms_path),
                    img_list=img_list.name,
                    img_dirs=[str(images)],
                    common=common_args,
                    data=data_args,
                    trainer=trainer_args,
                    decode=decode_args,
                    num_workers=1,
                )
                # Flush stdout to avoid output buffering
                sys.stdout.flush()
            predictions = (
                Path(pred_stdout.name).read_text(encoding="utf-8").strip().splitlines()
            )
    finally:
        # Each request writes a uniquely named file into the shared scratch
        # directory; delete it so the directory does not grow without bound.
        image_path.unlink(missing_ok=True)

    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    if len(predictions) != 1:
        raise RuntimeError(
            f"Expected exactly one prediction line from PyLaia, got {len(predictions)}."
        )
    match = LINE_PREDICTION.match(predictions[0])
    if match is None:
        raise RuntimeError(f"Could not parse PyLaia prediction: {predictions[0]!r}")
    return input_img, {"text": match.group("text"), "score": match.group("confidence")}
# Gradio UI: choose a model, upload a line image, and get back the resized
# image together with the decoded text and its confidence score.
gradio_app = gr.Interface(
    predict,
    inputs=[
        gr.Dropdown(models_name, value=models_name[0], label="Models"),
        gr.Image(
            label="Upload an image of a line",
            sources=["upload", "clipboard"],
            type="pil",
            height=DEFAULT_HEIGHT,
            width=2000,
            image_mode="L",
        ),
    ],
    outputs=[
        gr.Image(label="Processed Image"),
        gr.JSON(label="Decoded text"),
    ],
    # Every example is paired with the RIMES model. `iterdir()` order depends
    # on the filesystem, so sort for a stable gallery, and skip sub-directories
    # and hidden files, which are presumably not example images.
    examples=[
        ["Teklia/pylaia-rimes", str(filename)]
        for filename in sorted(Path("examples").iterdir())
        if filename.is_file() and not filename.name.startswith(".")
    ],
    title="Decode the transcription of an image using a PyLaia model",
    cache_examples=True,
)
def _main():
    """Start the Gradio app defined in this module."""
    gradio_app.launch()


# Only launch when executed as a script, not when imported.
if __name__ == "__main__":
    _main()
|