Spaces:
Running
on
T4
Running
on
T4
refactored the pipeline
Browse files- app.py +4 -4
- helper/gradio_config.py +3 -0
- requirements.txt +2 -0
- src/htr_pipeline/gradio_backend.py +28 -7
- src/htr_pipeline/pipeline.py +10 -6
- src/htr_pipeline/utils/helper.py +15 -0
- src/htr_pipeline/utils/parser_xml.py +0 -60
- src/htr_pipeline/utils/pipeline_inferencer.py +107 -0
- src/htr_pipeline/utils/process_xml.py +0 -167
- src/htr_pipeline/utils/visualize_xml.py +68 -0
- src/htr_pipeline/utils/xml_helper.py +55 -0
- tabs/htr_tool.py +5 -29
- tabs/stepwise_htr_tool.py +44 -58
app.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
import gradio as gr
|
2 |
|
3 |
-
from helper.gradio_config import css,
|
4 |
from helper.text.text_about import TextAbout
|
5 |
from helper.text.text_app import TextApp
|
6 |
from helper.text.text_howto import TextHowTo
|
@@ -21,7 +21,7 @@ with gr.Blocks(title="HTR Riksarkivet", theme=theme, css=css) as demo:
|
|
21 |
with gr.Tab("How to use"):
|
22 |
with gr.Tabs():
|
23 |
with gr.Tab("HTR Tool"):
|
24 |
-
with gr.Row(
|
25 |
with gr.Column():
|
26 |
gr.Markdown(TextHowTo.htr_tool)
|
27 |
with gr.Column():
|
@@ -33,7 +33,7 @@ with gr.Blocks(title="HTR Riksarkivet", theme=theme, css=css) as demo:
|
|
33 |
gr.Markdown(TextHowTo.reach_out)
|
34 |
|
35 |
with gr.Tab("Stepwise HTR Tool"):
|
36 |
-
with gr.Row(
|
37 |
gr.Markdown(TextHowTo.stepwise_htr_tool)
|
38 |
with gr.Row():
|
39 |
gr.Markdown(TextHowTo.stepwise_htr_tool_tab_intro)
|
@@ -115,7 +115,7 @@ print(job.result())
|
|
115 |
with gr.Column():
|
116 |
gr.Markdown(TextRoadmap.discussion)
|
117 |
|
118 |
-
demo.load(None, None, None, _js=js)
|
119 |
|
120 |
|
121 |
demo.queue(concurrency_count=1, max_size=1)
|
|
|
1 |
import gradio as gr
|
2 |
|
3 |
+
from helper.gradio_config import css, theme
|
4 |
from helper.text.text_about import TextAbout
|
5 |
from helper.text.text_app import TextApp
|
6 |
from helper.text.text_howto import TextHowTo
|
|
|
21 |
with gr.Tab("How to use"):
|
22 |
with gr.Tabs():
|
23 |
with gr.Tab("HTR Tool"):
|
24 |
+
with gr.Row(equal_height=False):
|
25 |
with gr.Column():
|
26 |
gr.Markdown(TextHowTo.htr_tool)
|
27 |
with gr.Column():
|
|
|
33 |
gr.Markdown(TextHowTo.reach_out)
|
34 |
|
35 |
with gr.Tab("Stepwise HTR Tool"):
|
36 |
+
with gr.Row(equal_height=False):
|
37 |
gr.Markdown(TextHowTo.stepwise_htr_tool)
|
38 |
with gr.Row():
|
39 |
gr.Markdown(TextHowTo.stepwise_htr_tool_tab_intro)
|
|
|
115 |
with gr.Column():
|
116 |
gr.Markdown(TextRoadmap.discussion)
|
117 |
|
118 |
+
# demo.load(None, None, None, _js=js)
|
119 |
|
120 |
|
121 |
demo.queue(concurrency_count=1, max_size=1)
|
helper/gradio_config.py
CHANGED
@@ -21,6 +21,9 @@ class GradioConfig:
|
|
21 |
#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 450px}
|
22 |
#gallery {height: 400px}
|
23 |
.fixed-height.svelte-g4rw9.svelte-g4rw9 {min-height: 400px;}
|
|
|
|
|
|
|
24 |
"""
|
25 |
|
26 |
def generate_tooltip_css(self):
|
|
|
21 |
#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 450px}
|
22 |
#gallery {height: 400px}
|
23 |
.fixed-height.svelte-g4rw9.svelte-g4rw9 {min-height: 400px;}
|
24 |
+
|
25 |
+
#gallery_lines > div.preview.svelte-1b19cri > div.thumbnails.scroll-hide.svelte-1b19cri {display: none;}
|
26 |
+
|
27 |
"""
|
28 |
|
29 |
def generate_tooltip_css(self):
|
requirements.txt
CHANGED
@@ -14,6 +14,8 @@ pillow==9.5.0
|
|
14 |
|
15 |
|
16 |
|
|
|
|
|
17 |
# make install_openmmlab (they are excuted in dockerfile)
|
18 |
# !pip install -U openmim
|
19 |
# !mim install mmengine
|
|
|
14 |
|
15 |
|
16 |
|
17 |
+
|
18 |
+
|
19 |
# make install_openmmlab (they are excuted in dockerfile)
|
20 |
# !pip install -U openmim
|
21 |
# !mim install mmengine
|
src/htr_pipeline/gradio_backend.py
CHANGED
@@ -5,6 +5,7 @@ import pandas as pd
|
|
5 |
|
6 |
from src.htr_pipeline.inferencer import Inferencer, InferencerInterface
|
7 |
from src.htr_pipeline.pipeline import Pipeline, PipelineInterface
|
|
|
8 |
|
9 |
|
10 |
class SingletonModelLoader:
|
@@ -28,6 +29,7 @@ class FastTrack:
|
|
28 |
self.pipeline: PipelineInterface = model_loader.pipeline
|
29 |
|
30 |
def segment_to_xml(self, image, radio_button_choices):
|
|
|
31 |
xml_xml = "page_xml.xml"
|
32 |
xml_txt = "page_txt.txt"
|
33 |
|
@@ -40,6 +42,11 @@ class FastTrack:
|
|
40 |
f.write(rendered_xml)
|
41 |
|
42 |
xml_img = self.visualize_xml_and_return_txt(image, xml_txt)
|
|
|
|
|
|
|
|
|
|
|
43 |
if len(radio_button_choices) < 2:
|
44 |
if radio_button_choices[0] == "Txt":
|
45 |
returned_file_extension = xml_txt
|
@@ -47,8 +54,7 @@ class FastTrack:
|
|
47 |
returned_file_extension = xml_xml
|
48 |
else:
|
49 |
returned_file_extension = [xml_txt, xml_xml]
|
50 |
-
|
51 |
-
return xml_img, returned_file_extension, gr.update(visible=True)
|
52 |
|
53 |
def segment_to_xml_api(self, image):
|
54 |
rendered_xml = self.pipeline.running_htr_pipeline(image)
|
@@ -70,12 +76,14 @@ class CustomTrack:
|
|
70 |
def __init__(self, model_loader):
|
71 |
self.inferencer: InferencerInterface = model_loader.inferencer
|
72 |
|
|
|
73 |
def region_segment(self, image, pred_score_threshold, containments_treshold):
|
74 |
predicted_regions, regions_cropped_ordered, _, _ = self.inferencer.predict_regions(
|
75 |
image, pred_score_threshold, containments_treshold
|
76 |
)
|
77 |
return predicted_regions, regions_cropped_ordered, gr.update(visible=False), gr.update(visible=True)
|
78 |
|
|
|
79 |
def line_segment(self, image, pred_score_threshold, containments_threshold):
|
80 |
predicted_lines, lines_cropped_ordered, _ = self.inferencer.predict_lines(
|
81 |
image, pred_score_threshold, containments_threshold
|
@@ -93,22 +101,35 @@ class CustomTrack:
|
|
93 |
)
|
94 |
|
95 |
def transcribe_text(self, df, images):
|
|
|
96 |
transcription_temp_list_with_score = []
|
97 |
mapping_dict = {}
|
98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
for image in images:
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
transcribed_text, prediction_score_from_htr = self.inferencer.transcribe(image)
|
101 |
transcription_temp_list_with_score.append((transcribed_text, prediction_score_from_htr))
|
102 |
|
103 |
df_trans_explore = pd.DataFrame(
|
104 |
-
transcription_temp_list_with_score, columns=["Transcribed text", "
|
105 |
)
|
106 |
|
107 |
mapping_dict[transcribed_text] = image
|
108 |
|
109 |
-
yield df_trans_explore[
|
110 |
-
|
111 |
-
|
112 |
|
113 |
def get_select_index_image(self, images_from_gallery, evt: gr.SelectData):
|
114 |
return images_from_gallery[evt.index]["name"]
|
@@ -120,7 +141,7 @@ class CustomTrack:
|
|
120 |
new_first = [sorted_image]
|
121 |
new_list = [img for txt, img in mapping_dict.items() if txt != key_text]
|
122 |
new_first.extend(new_list)
|
123 |
-
return new_first
|
124 |
|
125 |
def download_df_to_txt(self, transcribed_df):
|
126 |
text_in_list = transcribed_df["Transcribed text"].tolist()
|
|
|
5 |
|
6 |
from src.htr_pipeline.inferencer import Inferencer, InferencerInterface
|
7 |
from src.htr_pipeline.pipeline import Pipeline, PipelineInterface
|
8 |
+
from src.htr_pipeline.utils.helper import gradio_info
|
9 |
|
10 |
|
11 |
class SingletonModelLoader:
|
|
|
29 |
self.pipeline: PipelineInterface = model_loader.pipeline
|
30 |
|
31 |
def segment_to_xml(self, image, radio_button_choices):
|
32 |
+
gr.Info("Running HTR-pipeline")
|
33 |
xml_xml = "page_xml.xml"
|
34 |
xml_txt = "page_txt.txt"
|
35 |
|
|
|
42 |
f.write(rendered_xml)
|
43 |
|
44 |
xml_img = self.visualize_xml_and_return_txt(image, xml_txt)
|
45 |
+
returned_file_extension = self.file_extenstion_to_return(radio_button_choices, xml_xml, xml_txt)
|
46 |
+
|
47 |
+
return xml_img, returned_file_extension, gr.update(visible=True)
|
48 |
+
|
49 |
+
def file_extenstion_to_return(self, radio_button_choices, xml_xml, xml_txt):
|
50 |
if len(radio_button_choices) < 2:
|
51 |
if radio_button_choices[0] == "Txt":
|
52 |
returned_file_extension = xml_txt
|
|
|
54 |
returned_file_extension = xml_xml
|
55 |
else:
|
56 |
returned_file_extension = [xml_txt, xml_xml]
|
57 |
+
return returned_file_extension
|
|
|
58 |
|
59 |
def segment_to_xml_api(self, image):
|
60 |
rendered_xml = self.pipeline.running_htr_pipeline(image)
|
|
|
76 |
def __init__(self, model_loader):
|
77 |
self.inferencer: InferencerInterface = model_loader.inferencer
|
78 |
|
79 |
+
@gradio_info("Running Segment Region")
|
80 |
def region_segment(self, image, pred_score_threshold, containments_treshold):
|
81 |
predicted_regions, regions_cropped_ordered, _, _ = self.inferencer.predict_regions(
|
82 |
image, pred_score_threshold, containments_treshold
|
83 |
)
|
84 |
return predicted_regions, regions_cropped_ordered, gr.update(visible=False), gr.update(visible=True)
|
85 |
|
86 |
+
@gradio_info("Running Segment Line")
|
87 |
def line_segment(self, image, pred_score_threshold, containments_threshold):
|
88 |
predicted_lines, lines_cropped_ordered, _ = self.inferencer.predict_lines(
|
89 |
image, pred_score_threshold, containments_threshold
|
|
|
101 |
)
|
102 |
|
103 |
def transcribe_text(self, df, images):
|
104 |
+
gr.Info("Running Transcribe Lines")
|
105 |
transcription_temp_list_with_score = []
|
106 |
mapping_dict = {}
|
107 |
|
108 |
+
total_images = len(images)
|
109 |
+
current_index = 0
|
110 |
+
|
111 |
+
bool_to_show_placeholder = gr.update(visible=True)
|
112 |
+
bool_to_show_control_results_transcribe = gr.update(visible=False)
|
113 |
+
|
114 |
for image in images:
|
115 |
+
current_index += 1
|
116 |
+
|
117 |
+
if current_index == total_images:
|
118 |
+
bool_to_show_control_results_transcribe = gr.update(visible=True)
|
119 |
+
bool_to_show_placeholder = gr.update(visible=False)
|
120 |
+
|
121 |
transcribed_text, prediction_score_from_htr = self.inferencer.transcribe(image)
|
122 |
transcription_temp_list_with_score.append((transcribed_text, prediction_score_from_htr))
|
123 |
|
124 |
df_trans_explore = pd.DataFrame(
|
125 |
+
transcription_temp_list_with_score, columns=["Transcribed text", "Pred score"]
|
126 |
)
|
127 |
|
128 |
mapping_dict[transcribed_text] = image
|
129 |
|
130 |
+
yield df_trans_explore[
|
131 |
+
["Transcribed text"]
|
132 |
+
], df_trans_explore, mapping_dict, bool_to_show_control_results_transcribe, bool_to_show_placeholder
|
133 |
|
134 |
def get_select_index_image(self, images_from_gallery, evt: gr.SelectData):
|
135 |
return images_from_gallery[evt.index]["name"]
|
|
|
141 |
new_first = [sorted_image]
|
142 |
new_list = [img for txt, img in mapping_dict.items() if txt != key_text]
|
143 |
new_first.extend(new_list)
|
144 |
+
return new_first, key_text
|
145 |
|
146 |
def download_df_to_txt(self, transcribed_df):
|
147 |
text_in_list = transcribed_df["Transcribed text"].tolist()
|
src/htr_pipeline/pipeline.py
CHANGED
@@ -6,15 +6,18 @@ import numpy as np
|
|
6 |
from src.htr_pipeline.inferencer import Inferencer
|
7 |
from src.htr_pipeline.utils.helper import timer_func
|
8 |
from src.htr_pipeline.utils.parser_xml import XmlParser
|
|
|
9 |
from src.htr_pipeline.utils.preprocess_img import Preprocess
|
10 |
-
from src.htr_pipeline.utils.
|
|
|
|
|
11 |
|
12 |
|
13 |
class Pipeline:
|
14 |
def __init__(self, inferencer: Inferencer) -> None:
|
15 |
self.inferencer = inferencer
|
16 |
-
self.xml = XMLHelper()
|
17 |
self.preprocess_img = Preprocess()
|
|
|
18 |
|
19 |
@timer_func
|
20 |
def running_htr_pipeline(
|
@@ -27,7 +30,7 @@ class Pipeline:
|
|
27 |
input_image = self.preprocess_img.binarize_img(input_image)
|
28 |
image = mmcv.imread(input_image)
|
29 |
|
30 |
-
rendered_xml = self.
|
31 |
image, pred_score_threshold_regions, pred_score_threshold_lines, containments_threshold, self.inferencer
|
32 |
)
|
33 |
|
@@ -35,14 +38,15 @@ class Pipeline:
|
|
35 |
|
36 |
@timer_func
|
37 |
def visualize_xml(self, input_image: np.ndarray) -> np.ndarray:
|
38 |
-
|
39 |
bin_input_image = self.preprocess_img.binarize_img(input_image)
|
40 |
-
xml_image =
|
41 |
return xml_image
|
42 |
|
43 |
@timer_func
|
44 |
def parse_xml_to_txt(self) -> None:
|
45 |
-
|
|
|
46 |
|
47 |
|
48 |
class PipelineInterface(Protocol):
|
|
|
6 |
from src.htr_pipeline.inferencer import Inferencer
|
7 |
from src.htr_pipeline.utils.helper import timer_func
|
8 |
from src.htr_pipeline.utils.parser_xml import XmlParser
|
9 |
+
from src.htr_pipeline.utils.pipeline_inferencer import PipelineInferencer, XMLHelper
|
10 |
from src.htr_pipeline.utils.preprocess_img import Preprocess
|
11 |
+
from src.htr_pipeline.utils.process_segmask import SegMaskHelper
|
12 |
+
from src.htr_pipeline.utils.visualize_xml import XmlViz
|
13 |
+
from src.htr_pipeline.utils.xml_helper import XMLHelper
|
14 |
|
15 |
|
16 |
class Pipeline:
|
17 |
def __init__(self, inferencer: Inferencer) -> None:
|
18 |
self.inferencer = inferencer
|
|
|
19 |
self.preprocess_img = Preprocess()
|
20 |
+
self.pipeline_inferencer = PipelineInferencer(SegMaskHelper(), XMLHelper())
|
21 |
|
22 |
@timer_func
|
23 |
def running_htr_pipeline(
|
|
|
30 |
input_image = self.preprocess_img.binarize_img(input_image)
|
31 |
image = mmcv.imread(input_image)
|
32 |
|
33 |
+
rendered_xml = self.pipeline_inferencer.image_to_page_xml(
|
34 |
image, pred_score_threshold_regions, pred_score_threshold_lines, containments_threshold, self.inferencer
|
35 |
)
|
36 |
|
|
|
38 |
|
39 |
@timer_func
|
40 |
def visualize_xml(self, input_image: np.ndarray) -> np.ndarray:
|
41 |
+
xml_viz = XmlViz()
|
42 |
bin_input_image = self.preprocess_img.binarize_img(input_image)
|
43 |
+
xml_image = xml_viz.visualize_xml(bin_input_image)
|
44 |
return xml_image
|
45 |
|
46 |
@timer_func
|
47 |
def parse_xml_to_txt(self) -> None:
|
48 |
+
xml_visualizer_and_parser = XmlParser()
|
49 |
+
xml_visualizer_and_parser.xml_to_txt()
|
50 |
|
51 |
|
52 |
class PipelineInterface(Protocol):
|
src/htr_pipeline/utils/helper.py
CHANGED
@@ -1,7 +1,9 @@
|
|
1 |
import functools
|
2 |
import threading
|
3 |
import time
|
|
|
4 |
|
|
|
5 |
import tqdm
|
6 |
|
7 |
|
@@ -75,6 +77,19 @@ def another_long_running_function(*args, **kwargs):
|
|
75 |
return "success"
|
76 |
|
77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
if __name__ == "__main__":
|
79 |
# Basic example
|
80 |
retval = provide_progress_bar(long_running_function, estimated_time=5)
|
|
|
1 |
import functools
|
2 |
import threading
|
3 |
import time
|
4 |
+
from functools import wraps
|
5 |
|
6 |
+
import gradio as gr
|
7 |
import tqdm
|
8 |
|
9 |
|
|
|
77 |
return "success"
|
78 |
|
79 |
|
80 |
+
# Decorator for logging
|
81 |
+
def gradio_info(message):
|
82 |
+
def decorator(func):
|
83 |
+
@wraps(func)
|
84 |
+
def wrapper(*args, **kwargs):
|
85 |
+
gr.Info(message)
|
86 |
+
return func(*args, **kwargs)
|
87 |
+
|
88 |
+
return wrapper
|
89 |
+
|
90 |
+
return decorator
|
91 |
+
|
92 |
+
|
93 |
if __name__ == "__main__":
|
94 |
# Basic example
|
95 |
retval = provide_progress_bar(long_running_function, estimated_time=5)
|
src/htr_pipeline/utils/parser_xml.py
CHANGED
@@ -1,10 +1,5 @@
|
|
1 |
-
import math
|
2 |
-
import os
|
3 |
-
import random
|
4 |
import xml.etree.ElementTree as ET
|
5 |
|
6 |
-
from PIL import Image, ImageDraw, ImageFont
|
7 |
-
|
8 |
|
9 |
class XmlParser:
|
10 |
def __init__(self, page_xml="./page_xml.xml"):
|
@@ -12,61 +7,6 @@ class XmlParser:
|
|
12 |
self.root = self.tree.getroot()
|
13 |
self.namespace = "{http://schema.primaresearch.org/PAGE/gts/pagecontent/2013-07-15}"
|
14 |
|
15 |
-
def visualize_xml(
|
16 |
-
self,
|
17 |
-
background_image,
|
18 |
-
font_size=9,
|
19 |
-
text_offset=10,
|
20 |
-
font_path_tff="./src/htr_pipeline/utils/templates/arial.ttf",
|
21 |
-
):
|
22 |
-
image = Image.fromarray(background_image).convert("RGBA")
|
23 |
-
image_width = int(self.root.find(f"{self.namespace}Page").attrib["imageWidth"])
|
24 |
-
image_height = int(self.root.find(f"{self.namespace}Page").attrib["imageHeight"])
|
25 |
-
|
26 |
-
text_offset = -text_offset
|
27 |
-
base_font_size = font_size
|
28 |
-
font_path = font_path_tff
|
29 |
-
|
30 |
-
max_bbox_width = 0 # Initialize maximum bounding box width
|
31 |
-
|
32 |
-
for textregion in self.root.findall(f".//{self.namespace}TextRegion"):
|
33 |
-
coords = textregion.find(f"{self.namespace}Coords").attrib["points"].split()
|
34 |
-
points = [tuple(map(int, point.split(","))) for point in coords]
|
35 |
-
x_coords, y_coords = zip(*points)
|
36 |
-
min_x, max_x = min(x_coords), max(x_coords)
|
37 |
-
bbox_width = max_x - min_x # Width of the current bounding box
|
38 |
-
max_bbox_width = max(max_bbox_width, bbox_width) # Update maximum bounding box width
|
39 |
-
|
40 |
-
scaling_factor = max_bbox_width / 400.0 # Use maximum bounding box width for scaling
|
41 |
-
font_size_scaled = int(base_font_size * scaling_factor)
|
42 |
-
font = ImageFont.truetype(font_path, font_size_scaled)
|
43 |
-
|
44 |
-
for textregion in self.root.findall(f".//{self.namespace}TextRegion"):
|
45 |
-
fill_color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 100)
|
46 |
-
for textline in textregion.findall(f".//{self.namespace}TextLine"):
|
47 |
-
coords = textline.find(f"{self.namespace}Coords").attrib["points"].split()
|
48 |
-
points = [tuple(map(int, point.split(","))) for point in coords]
|
49 |
-
|
50 |
-
poly_image = Image.new("RGBA", image.size)
|
51 |
-
poly_draw = ImageDraw.Draw(poly_image)
|
52 |
-
poly_draw.polygon(points, fill=fill_color)
|
53 |
-
|
54 |
-
text = textline.find(f"{self.namespace}TextEquiv").find(f"{self.namespace}Unicode").text
|
55 |
-
|
56 |
-
x_coords, y_coords = zip(*points)
|
57 |
-
min_x, max_x = min(x_coords), max(x_coords)
|
58 |
-
min_y = min(y_coords)
|
59 |
-
text_width, text_height = poly_draw.textsize(text, font=font) # Get text size
|
60 |
-
text_position = (
|
61 |
-
(min_x + max_x) // 2 - text_width // 2,
|
62 |
-
min_y + text_offset,
|
63 |
-
) # Center text horizontally
|
64 |
-
|
65 |
-
poly_draw.text(text_position, text, fill=(0, 0, 0), font=font)
|
66 |
-
image = Image.alpha_composite(image, poly_image)
|
67 |
-
|
68 |
-
return image
|
69 |
-
|
70 |
def xml_to_txt(self, output_file="page_txt.txt"):
|
71 |
with open(output_file, "w", encoding="utf-8") as f:
|
72 |
for textregion in self.root.findall(f".//{self.namespace}TextRegion"):
|
|
|
|
|
|
|
|
|
1 |
import xml.etree.ElementTree as ET
|
2 |
|
|
|
|
|
3 |
|
4 |
class XmlParser:
|
5 |
def __init__(self, page_xml="./page_xml.xml"):
|
|
|
7 |
self.root = self.tree.getroot()
|
8 |
self.namespace = "{http://schema.primaresearch.org/PAGE/gts/pagecontent/2013-07-15}"
|
9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
def xml_to_txt(self, output_file="page_txt.txt"):
|
11 |
with open(output_file, "w", encoding="utf-8") as f:
|
12 |
for textregion in self.root.findall(f".//{self.namespace}TextRegion"):
|
src/htr_pipeline/utils/pipeline_inferencer.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from tqdm import tqdm
|
2 |
+
|
3 |
+
from src.htr_pipeline.utils.process_segmask import SegMaskHelper
|
4 |
+
from src.htr_pipeline.utils.xml_helper import XMLHelper
|
5 |
+
|
6 |
+
|
7 |
+
class PipelineInferencer:
|
8 |
+
def __init__(self, process_seg_mask: SegMaskHelper, xml_helper: XMLHelper):
|
9 |
+
self.process_seg_mask = process_seg_mask
|
10 |
+
self.xml_helper = xml_helper
|
11 |
+
|
12 |
+
def image_to_page_xml(
|
13 |
+
self, image, pred_score_threshold_regions, pred_score_threshold_lines, containments_threshold, inferencer
|
14 |
+
):
|
15 |
+
template_data = self.xml_helper.prepare_template_data(self.xml_helper.xml_file_name, image)
|
16 |
+
template_data["textRegions"] = self._process_regions(
|
17 |
+
image, inferencer, pred_score_threshold_regions, pred_score_threshold_lines, containments_threshold
|
18 |
+
)
|
19 |
+
|
20 |
+
print(template_data)
|
21 |
+
return self.xml_helper.render(template_data)
|
22 |
+
|
23 |
+
def _process_regions(
|
24 |
+
self,
|
25 |
+
image,
|
26 |
+
inferencer,
|
27 |
+
pred_score_threshold_regions,
|
28 |
+
pred_score_threshold_lines,
|
29 |
+
containments_threshold,
|
30 |
+
htr_threshold=0.7,
|
31 |
+
):
|
32 |
+
_, regions_cropped_ordered, reg_polygons_ordered, reg_masks_ordered = inferencer.predict_regions(
|
33 |
+
image,
|
34 |
+
pred_score_threshold=pred_score_threshold_regions,
|
35 |
+
containments_threshold=containments_threshold,
|
36 |
+
visualize=False,
|
37 |
+
)
|
38 |
+
|
39 |
+
region_data_list = []
|
40 |
+
for i, data in tqdm(enumerate(zip(regions_cropped_ordered, reg_polygons_ordered, reg_masks_ordered))):
|
41 |
+
region_data = self._create_region_data(
|
42 |
+
data, i, inferencer, pred_score_threshold_lines, containments_threshold, htr_threshold
|
43 |
+
)
|
44 |
+
if region_data:
|
45 |
+
region_data_list.append(region_data)
|
46 |
+
|
47 |
+
return region_data_list
|
48 |
+
|
49 |
+
def _create_region_data(
|
50 |
+
self, data, index, inferencer, pred_score_threshold_lines, containments_threshold, htr_threshold
|
51 |
+
):
|
52 |
+
text_region, reg_pol, mask = data
|
53 |
+
region_data = {"id": f"region_{index}", "boundary": reg_pol}
|
54 |
+
|
55 |
+
text_lines, htr_scores = self._process_lines(
|
56 |
+
text_region,
|
57 |
+
inferencer,
|
58 |
+
pred_score_threshold_lines,
|
59 |
+
containments_threshold,
|
60 |
+
mask,
|
61 |
+
region_data["id"],
|
62 |
+
htr_threshold,
|
63 |
+
)
|
64 |
+
|
65 |
+
if not text_lines:
|
66 |
+
return None
|
67 |
+
|
68 |
+
region_data["textLines"] = text_lines
|
69 |
+
mean_htr_score = sum(htr_scores) / len(htr_scores) if htr_scores else 0
|
70 |
+
|
71 |
+
return region_data if mean_htr_score > htr_threshold else None
|
72 |
+
|
73 |
+
def _process_lines(
|
74 |
+
self, text_region, inferencer, pred_score_threshold, containments_threshold, mask, region_id, htr_threshold=0.7
|
75 |
+
):
|
76 |
+
_, lines_cropped_ordered, line_polygons_ordered = inferencer.predict_lines(
|
77 |
+
text_region, pred_score_threshold, containments_threshold, visualize=False, custom_track=False
|
78 |
+
)
|
79 |
+
|
80 |
+
if not lines_cropped_ordered:
|
81 |
+
return None, []
|
82 |
+
|
83 |
+
line_polygons_ordered_trans = self.process_seg_mask._translate_line_coords(mask, line_polygons_ordered)
|
84 |
+
|
85 |
+
text_lines = []
|
86 |
+
htr_scores = []
|
87 |
+
for index, (line, line_pol) in enumerate(zip(lines_cropped_ordered, line_polygons_ordered_trans)):
|
88 |
+
line_data, htr_score = self._create_line_data(line, line_pol, index, region_id, inferencer, htr_threshold)
|
89 |
+
|
90 |
+
if line_data:
|
91 |
+
text_lines.append(line_data)
|
92 |
+
htr_scores.append(htr_score)
|
93 |
+
|
94 |
+
return text_lines, htr_scores
|
95 |
+
|
96 |
+
def _create_line_data(self, line, line_pol, index, region_id, inferencer, htr_threshold):
|
97 |
+
line_data = {"id": f"line_{region_id}_{index}", "boundary": line_pol}
|
98 |
+
|
99 |
+
transcribed_text, htr_score = inferencer.transcribe(line)
|
100 |
+
line_data["unicode"] = self.xml_helper.escape_xml_chars(transcribed_text)
|
101 |
+
line_data["pred_score"] = round(htr_score, 4)
|
102 |
+
|
103 |
+
return line_data if htr_score > htr_threshold else None, htr_score
|
104 |
+
|
105 |
+
|
106 |
+
if __name__ == "__main__":
|
107 |
+
pass
|
src/htr_pipeline/utils/process_xml.py
DELETED
@@ -1,167 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import re
|
3 |
-
from datetime import datetime
|
4 |
-
|
5 |
-
import jinja2
|
6 |
-
from tqdm import tqdm
|
7 |
-
|
8 |
-
from src.htr_pipeline.inferencer import InferencerInterface
|
9 |
-
from src.htr_pipeline.utils.process_segmask import SegMaskHelper
|
10 |
-
|
11 |
-
|
12 |
-
class XMLHelper:
|
13 |
-
def __init__(self):
|
14 |
-
self.process_seg_mask = SegMaskHelper()
|
15 |
-
|
16 |
-
def image_to_page_xml(
|
17 |
-
self,
|
18 |
-
image,
|
19 |
-
pred_score_threshold_regions,
|
20 |
-
pred_score_threshold_lines,
|
21 |
-
containments_threshold,
|
22 |
-
inferencer: InferencerInterface,
|
23 |
-
xml_file_name="page_xml.xml",
|
24 |
-
):
|
25 |
-
img_height = image.shape[0]
|
26 |
-
img_width = image.shape[1]
|
27 |
-
img_file_name = xml_file_name
|
28 |
-
|
29 |
-
template_data = self.prepare_template_data(img_file_name, img_width, img_height)
|
30 |
-
|
31 |
-
template_data["textRegions"] = self._process_regions(
|
32 |
-
image,
|
33 |
-
inferencer,
|
34 |
-
pred_score_threshold_regions,
|
35 |
-
pred_score_threshold_lines,
|
36 |
-
containments_threshold,
|
37 |
-
)
|
38 |
-
|
39 |
-
rendered_xml = self._render_xml(template_data)
|
40 |
-
|
41 |
-
return rendered_xml
|
42 |
-
|
43 |
-
def _transform_coords(self, input_string):
|
44 |
-
pattern = r"\[\s*([^\s,]+)\s*,\s*([^\s\]]+)\s*\]"
|
45 |
-
replacement = r"\1,\2"
|
46 |
-
return re.sub(pattern, replacement, input_string)
|
47 |
-
|
48 |
-
def _render_xml(self, template_data):
|
49 |
-
template_loader = jinja2.FileSystemLoader(searchpath="./src/htr_pipeline/utils/templates")
|
50 |
-
template_env = jinja2.Environment(loader=template_loader, trim_blocks=True)
|
51 |
-
template = template_env.get_template("page_xml_2013.xml")
|
52 |
-
rendered_xml = template.render(template_data)
|
53 |
-
rendered_xml = self._transform_coords(rendered_xml)
|
54 |
-
return rendered_xml
|
55 |
-
|
56 |
-
def prepare_template_data(self, img_file_name, img_width, img_height):
|
57 |
-
now = datetime.now()
|
58 |
-
date_time = now.strftime("%Y-%m-%d, %H:%M:%S")
|
59 |
-
return {
|
60 |
-
"created": date_time,
|
61 |
-
"imageFilename": img_file_name,
|
62 |
-
"imageWidth": img_width,
|
63 |
-
"imageHeight": img_height,
|
64 |
-
"textRegions": list(),
|
65 |
-
}
|
66 |
-
|
67 |
-
def _process_regions(
|
68 |
-
self,
|
69 |
-
image,
|
70 |
-
inferencer: InferencerInterface,
|
71 |
-
pred_score_threshold_regions,
|
72 |
-
pred_score_threshold_lines,
|
73 |
-
containments_threshold,
|
74 |
-
htr_threshold=0.7,
|
75 |
-
):
|
76 |
-
_, regions_cropped_ordered, reg_polygons_ordered, reg_masks_ordered = inferencer.predict_regions(
|
77 |
-
image,
|
78 |
-
pred_score_threshold=pred_score_threshold_regions,
|
79 |
-
containments_threshold=containments_threshold,
|
80 |
-
visualize=False,
|
81 |
-
)
|
82 |
-
|
83 |
-
region_data_list = []
|
84 |
-
for i, (text_region, reg_pol, mask) in tqdm(
|
85 |
-
enumerate(zip(regions_cropped_ordered, reg_polygons_ordered, reg_masks_ordered))
|
86 |
-
):
|
87 |
-
region_id = "region_" + str(i)
|
88 |
-
region_data = dict()
|
89 |
-
region_data["id"] = region_id
|
90 |
-
region_data["boundary"] = reg_pol
|
91 |
-
|
92 |
-
text_lines, htr_scores = self._process_lines(
|
93 |
-
text_region,
|
94 |
-
inferencer,
|
95 |
-
pred_score_threshold_lines,
|
96 |
-
containments_threshold,
|
97 |
-
mask,
|
98 |
-
region_id,
|
99 |
-
)
|
100 |
-
|
101 |
-
if text_lines is None:
|
102 |
-
continue
|
103 |
-
|
104 |
-
region_data["textLines"] = text_lines
|
105 |
-
mean_htr_score = sum(htr_scores) / len(htr_scores)
|
106 |
-
|
107 |
-
if mean_htr_score > htr_threshold:
|
108 |
-
region_data_list.append(region_data)
|
109 |
-
|
110 |
-
return region_data_list
|
111 |
-
|
112 |
-
def _process_lines(
|
113 |
-
self,
|
114 |
-
text_region,
|
115 |
-
inferencer: InferencerInterface,
|
116 |
-
pred_score_threshold_lines,
|
117 |
-
containments_threshold,
|
118 |
-
mask,
|
119 |
-
region_id,
|
120 |
-
htr_threshold=0.7,
|
121 |
-
):
|
122 |
-
_, lines_cropped_ordered, line_polygons_ordered = inferencer.predict_lines(
|
123 |
-
text_region,
|
124 |
-
pred_score_threshold=pred_score_threshold_lines,
|
125 |
-
containments_threshold=containments_threshold,
|
126 |
-
visualize=False,
|
127 |
-
custom_track=False,
|
128 |
-
)
|
129 |
-
|
130 |
-
if lines_cropped_ordered is None:
|
131 |
-
return None, None
|
132 |
-
|
133 |
-
line_polygons_ordered_trans = self.process_seg_mask._translate_line_coords(mask, line_polygons_ordered)
|
134 |
-
|
135 |
-
htr_scores = list()
|
136 |
-
text_lines = list()
|
137 |
-
|
138 |
-
for j, (line, line_pol) in enumerate(zip(lines_cropped_ordered, line_polygons_ordered_trans)):
|
139 |
-
line_id = "line_" + region_id + "_" + str(j)
|
140 |
-
line_data = dict()
|
141 |
-
line_data["id"] = line_id
|
142 |
-
line_data["boundary"] = line_pol
|
143 |
-
|
144 |
-
transcribed_text, htr_score = inferencer.transcribe(line)
|
145 |
-
escaped_text = self._escape_xml_chars(transcribed_text)
|
146 |
-
line_data["unicode"] = escaped_text
|
147 |
-
line_data["pred_score"] = round(htr_score, 4)
|
148 |
-
|
149 |
-
htr_scores.append(htr_score)
|
150 |
-
|
151 |
-
if htr_score > htr_threshold:
|
152 |
-
text_lines.append(line_data)
|
153 |
-
|
154 |
-
return text_lines, htr_scores
|
155 |
-
|
156 |
-
def _escape_xml_chars(self, textline):
|
157 |
-
return (
|
158 |
-
textline.replace("&", "&")
|
159 |
-
.replace("<", "<")
|
160 |
-
.replace(">", ">")
|
161 |
-
.replace("'", "'")
|
162 |
-
.replace('"', """)
|
163 |
-
)
|
164 |
-
|
165 |
-
|
166 |
-
if __name__ == "__main__":
|
167 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/htr_pipeline/utils/visualize_xml.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import xml.etree.ElementTree as ET
|
3 |
+
|
4 |
+
from PIL import Image, ImageDraw, ImageFont
|
5 |
+
|
6 |
+
|
7 |
+
class XmlViz:
|
8 |
+
def __init__(self, page_xml="./page_xml.xml"):
|
9 |
+
self.tree = ET.parse(page_xml, parser=ET.XMLParser(encoding="utf-8"))
|
10 |
+
self.root = self.tree.getroot()
|
11 |
+
self.namespace = "{http://schema.primaresearch.org/PAGE/gts/pagecontent/2013-07-15}"
|
12 |
+
|
13 |
+
def visualize_xml(
|
14 |
+
self,
|
15 |
+
background_image,
|
16 |
+
font_size=9,
|
17 |
+
text_offset=10,
|
18 |
+
font_path_tff="./src/htr_pipeline/utils/templates/arial.ttf",
|
19 |
+
):
|
20 |
+
image = Image.fromarray(background_image).convert("RGBA")
|
21 |
+
|
22 |
+
text_offset = -text_offset
|
23 |
+
base_font_size = font_size
|
24 |
+
font_path = font_path_tff
|
25 |
+
|
26 |
+
max_bbox_width = 0 # Initialize maximum bounding box width
|
27 |
+
|
28 |
+
for textregion in self.root.findall(f".//{self.namespace}TextRegion"):
|
29 |
+
coords = textregion.find(f"{self.namespace}Coords").attrib["points"].split()
|
30 |
+
points = [tuple(map(int, point.split(","))) for point in coords]
|
31 |
+
x_coords, y_coords = zip(*points)
|
32 |
+
min_x, max_x = min(x_coords), max(x_coords)
|
33 |
+
bbox_width = max_x - min_x # Width of the current bounding box
|
34 |
+
max_bbox_width = max(max_bbox_width, bbox_width) # Update maximum bounding box width
|
35 |
+
|
36 |
+
scaling_factor = max_bbox_width / 400.0 # Use maximum bounding box width for scaling
|
37 |
+
font_size_scaled = int(base_font_size * scaling_factor)
|
38 |
+
font = ImageFont.truetype(font_path, font_size_scaled)
|
39 |
+
|
40 |
+
for textregion in self.root.findall(f".//{self.namespace}TextRegion"):
|
41 |
+
fill_color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 100)
|
42 |
+
for textline in textregion.findall(f".//{self.namespace}TextLine"):
|
43 |
+
coords = textline.find(f"{self.namespace}Coords").attrib["points"].split()
|
44 |
+
points = [tuple(map(int, point.split(","))) for point in coords]
|
45 |
+
|
46 |
+
poly_image = Image.new("RGBA", image.size)
|
47 |
+
poly_draw = ImageDraw.Draw(poly_image)
|
48 |
+
poly_draw.polygon(points, fill=fill_color)
|
49 |
+
|
50 |
+
text = textline.find(f"{self.namespace}TextEquiv").find(f"{self.namespace}Unicode").text
|
51 |
+
|
52 |
+
x_coords, y_coords = zip(*points)
|
53 |
+
min_x, max_x = min(x_coords), max(x_coords)
|
54 |
+
min_y = min(y_coords)
|
55 |
+
text_width, text_height = poly_draw.textsize(text, font=font) # Get text size
|
56 |
+
text_position = (
|
57 |
+
(min_x + max_x) // 2 - text_width // 2,
|
58 |
+
min_y + text_offset,
|
59 |
+
) # Center text horizontally
|
60 |
+
|
61 |
+
poly_draw.text(text_position, text, fill=(0, 0, 0), font=font)
|
62 |
+
image = Image.alpha_composite(image, poly_image)
|
63 |
+
|
64 |
+
return image
|
65 |
+
|
66 |
+
|
67 |
+
if __name__ == "__main__":
|
68 |
+
pass
|
src/htr_pipeline/utils/xml_helper.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from datetime import datetime
|
3 |
+
|
4 |
+
import jinja2
|
5 |
+
|
6 |
+
|
7 |
+
class XMLHelper:
|
8 |
+
def __init__(self, xml_file_name="page_xml.xml"):
|
9 |
+
self.xml_file_name = xml_file_name
|
10 |
+
self.searchpath = "./src/htr_pipeline/utils/templates"
|
11 |
+
self.template = "page_xml_2013.xml"
|
12 |
+
|
13 |
+
def render(self, template_data):
|
14 |
+
rendered_xml = self._render_xml(template_data)
|
15 |
+
return rendered_xml
|
16 |
+
|
17 |
+
def _transform_coords(self, input_string):
|
18 |
+
pattern = r"\[\s*([^\s,]+)\s*,\s*([^\s\]]+)\s*\]"
|
19 |
+
replacement = r"\1,\2"
|
20 |
+
return re.sub(pattern, replacement, input_string)
|
21 |
+
|
22 |
+
def _render_xml(self, template_data):
|
23 |
+
template_loader = jinja2.FileSystemLoader(searchpath=self.searchpath)
|
24 |
+
template_env = jinja2.Environment(loader=template_loader, trim_blocks=True)
|
25 |
+
template = template_env.get_template(self.template)
|
26 |
+
rendered_xml = template.render(template_data)
|
27 |
+
rendered_xml = self._transform_coords(rendered_xml)
|
28 |
+
return rendered_xml
|
29 |
+
|
30 |
+
def prepare_template_data(self, img_file_name, image):
|
31 |
+
img_height = image.shape[0]
|
32 |
+
img_width = image.shape[1]
|
33 |
+
|
34 |
+
now = datetime.now()
|
35 |
+
date_time = now.strftime("%Y-%m-%d, %H:%M:%S")
|
36 |
+
return {
|
37 |
+
"created": date_time,
|
38 |
+
"imageFilename": img_file_name,
|
39 |
+
"imageWidth": img_width,
|
40 |
+
"imageHeight": img_height,
|
41 |
+
"textRegions": list(),
|
42 |
+
}
|
43 |
+
|
44 |
+
def escape_xml_chars(self, textline):
|
45 |
+
return (
|
46 |
+
textline.replace("&", "&")
|
47 |
+
.replace("<", "<")
|
48 |
+
.replace(">", ">")
|
49 |
+
.replace("'", "'")
|
50 |
+
.replace('"', """)
|
51 |
+
)
|
52 |
+
|
53 |
+
|
54 |
+
if __name__ == "__main__":
|
55 |
+
pass
|
tabs/htr_tool.py
CHANGED
@@ -19,32 +19,17 @@ with gr.Blocks() as htr_tool_tab:
|
|
19 |
)
|
20 |
|
21 |
with gr.Row():
|
22 |
-
# with gr.Group():
|
23 |
-
# callback = gr.CSVLogger()
|
24 |
-
# # hf_writer = gr.HuggingFaceDatasetSaver(HF_API_TOKEN, "htr_pipelin_flags")
|
25 |
-
# flagging_button = gr.Button(
|
26 |
-
# "Flag",
|
27 |
-
# variant="secondary",
|
28 |
-
# visible=True,
|
29 |
-
# ).style(full_width=True)
|
30 |
-
# radio_file_input = gr.Radio(
|
31 |
-
# value="Text file", choices=["Text file ", "Page XML file "], label="What kind file output?"
|
32 |
-
# )
|
33 |
-
|
34 |
radio_file_input = gr.CheckboxGroup(
|
35 |
choices=["Txt", "XML"],
|
36 |
-
value=["
|
37 |
label="Output file extension",
|
38 |
# info="Only txt and page xml is supported for now!",
|
|
|
39 |
)
|
40 |
|
41 |
htr_pipeline_button = gr.Button(
|
42 |
-
"Run HTR",
|
43 |
-
|
44 |
-
visible=True,
|
45 |
-
elem_id="run_pipeline_button",
|
46 |
-
).style(full_width=False)
|
47 |
-
|
48 |
with gr.Group():
|
49 |
with gr.Row():
|
50 |
fast_file_downlod = gr.File(label="Download output file", visible=False)
|
@@ -75,11 +60,7 @@ with gr.Blocks() as htr_tool_tab:
|
|
75 |
fast_track_output_image = gr.Image(label="HTR results visualizer", type="numpy", tool="editor", height=650)
|
76 |
|
77 |
with gr.Row(visible=False) as api_placeholder:
|
78 |
-
htr_pipeline_button_api = gr.Button(
|
79 |
-
"Run pipeline",
|
80 |
-
variant="primary",
|
81 |
-
visible=False,
|
82 |
-
).style(full_width=False)
|
83 |
|
84 |
xml_rendered_placeholder_for_api = gr.Textbox(visible=False)
|
85 |
htr_pipeline_button.click(
|
@@ -94,8 +75,3 @@ with gr.Blocks() as htr_tool_tab:
|
|
94 |
outputs=[xml_rendered_placeholder_for_api],
|
95 |
api_name="predict",
|
96 |
)
|
97 |
-
|
98 |
-
# callback.setup([fast_track_input_region_image], "flagged_data_points")
|
99 |
-
# flagging_button.click(lambda *args: callback.flag(args), [fast_track_input_region_image], None, preprocess=False)
|
100 |
-
# flagging_button.click(lambda: (gr.update(value="Flagged")), outputs=flagging_button)
|
101 |
-
# fast_track_input_region_image.change(lambda: (gr.update(value="Flag")), outputs=flagging_button)
|
|
|
19 |
)
|
20 |
|
21 |
with gr.Row():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
radio_file_input = gr.CheckboxGroup(
|
23 |
choices=["Txt", "XML"],
|
24 |
+
value=["XML"],
|
25 |
label="Output file extension",
|
26 |
# info="Only txt and page xml is supported for now!",
|
27 |
+
scale=1,
|
28 |
)
|
29 |
|
30 |
htr_pipeline_button = gr.Button(
|
31 |
+
"Run HTR", variant="primary", visible=True, elem_id="run_pipeline_button", scale=1
|
32 |
+
)
|
|
|
|
|
|
|
|
|
33 |
with gr.Group():
|
34 |
with gr.Row():
|
35 |
fast_file_downlod = gr.File(label="Download output file", visible=False)
|
|
|
60 |
fast_track_output_image = gr.Image(label="HTR results visualizer", type="numpy", tool="editor", height=650)
|
61 |
|
62 |
with gr.Row(visible=False) as api_placeholder:
|
63 |
+
htr_pipeline_button_api = gr.Button("Run pipeline", variant="primary", visible=False, scale=1)
|
|
|
|
|
|
|
|
|
64 |
|
65 |
xml_rendered_placeholder_for_api = gr.Textbox(visible=False)
|
66 |
htr_pipeline_button.click(
|
|
|
75 |
outputs=[xml_rendered_placeholder_for_api],
|
76 |
api_name="predict",
|
77 |
)
|
|
|
|
|
|
|
|
|
|
tabs/stepwise_htr_tool.py
CHANGED
@@ -25,7 +25,8 @@ with gr.Blocks() as stepwise_htr_tool_tab:
|
|
25 |
label="Image to Region segment",
|
26 |
# type="numpy",
|
27 |
tool="editor",
|
28 |
-
|
|
|
29 |
|
30 |
with gr.Accordion("Region segment settings:", open=False):
|
31 |
with gr.Row():
|
@@ -63,7 +64,7 @@ with gr.Blocks() as stepwise_htr_tool_tab:
|
|
63 |
"Segment Region",
|
64 |
variant="primary",
|
65 |
elem_id="region_segment_button",
|
66 |
-
)
|
67 |
|
68 |
with gr.Row():
|
69 |
with gr.Accordion("Example images to use:", open=False) as example_accord:
|
@@ -75,7 +76,7 @@ with gr.Blocks() as stepwise_htr_tool_tab:
|
|
75 |
)
|
76 |
|
77 |
with gr.Column(scale=3):
|
78 |
-
output_region_image = gr.Image(label="Segmented regions", type="numpy"
|
79 |
|
80 |
##############################################
|
81 |
with gr.Tab("2. Line Segmentation"):
|
@@ -84,27 +85,27 @@ with gr.Blocks() as stepwise_htr_tool_tab:
|
|
84 |
# type="numpy",
|
85 |
interactive="False",
|
86 |
visible=True,
|
87 |
-
|
|
|
88 |
|
89 |
with gr.Row(visible=False) as control_line_segment:
|
90 |
with gr.Column(scale=2):
|
91 |
with gr.Box():
|
92 |
regions_cropped_gallery = gr.Gallery(
|
93 |
label="Segmented regions",
|
94 |
-
show_label=False,
|
95 |
elem_id="gallery",
|
96 |
-
).style(
|
97 |
columns=[2],
|
98 |
rows=[2],
|
99 |
# object_fit="contain",
|
100 |
-
height=
|
101 |
preview=True,
|
102 |
container=False,
|
103 |
)
|
104 |
|
105 |
input_region_from_gallery = gr.Image(
|
106 |
-
label="Region segmentation to line segment", interactive="False", visible=False
|
107 |
-
)
|
|
|
108 |
with gr.Row():
|
109 |
with gr.Accordion("Line segment settings:", open=False):
|
110 |
with gr.Row():
|
@@ -126,7 +127,7 @@ with gr.Blocks() as stepwise_htr_tool_tab:
|
|
126 |
info="""The minimum required overlap or similarity
|
127 |
for a detected region or object to be considered valid""",
|
128 |
)
|
129 |
-
with gr.Row(
|
130 |
line_segment_model_dropdown = gr.Dropdown(
|
131 |
choices=["Riksarkivet/RmtDet_lines"],
|
132 |
value="Riksarkivet/RmtDet_lines",
|
@@ -138,22 +139,22 @@ with gr.Blocks() as stepwise_htr_tool_tab:
|
|
138 |
" ",
|
139 |
variant="Secondary",
|
140 |
# elem_id="center_button",
|
141 |
-
|
|
|
142 |
|
143 |
line_segment_button = gr.Button(
|
144 |
"Segment Lines",
|
145 |
variant="primary",
|
146 |
# elem_id="center_button",
|
147 |
-
|
|
|
148 |
|
149 |
with gr.Column(scale=3):
|
150 |
# gr.Markdown("""lorem ipsum""")
|
151 |
|
152 |
output_line_from_region = gr.Image(
|
153 |
-
label="Segmented lines",
|
154 |
-
|
155 |
-
interactive="False",
|
156 |
-
).style(height=600)
|
157 |
|
158 |
###############################################
|
159 |
with gr.Tab("3. Transcribe Text"):
|
@@ -162,19 +163,16 @@ with gr.Blocks() as stepwise_htr_tool_tab:
|
|
162 |
# type="numpy",
|
163 |
interactive="False",
|
164 |
visible=True,
|
165 |
-
|
|
|
166 |
|
167 |
with gr.Row(visible=False) as control_htr:
|
168 |
inputs_lines_to_transcribe = gr.Variable()
|
169 |
|
170 |
with gr.Column(scale=2):
|
171 |
image_inputs_lines_to_transcribe = gr.Image(
|
172 |
-
label="Transcribed lines",
|
173 |
-
|
174 |
-
interactive="False",
|
175 |
-
visible=False,
|
176 |
-
).style(height=470)
|
177 |
-
|
178 |
with gr.Row():
|
179 |
with gr.Accordion("Transcribe settings:", open=False):
|
180 |
transcriber_model = gr.Dropdown(
|
@@ -184,30 +182,21 @@ with gr.Blocks() as stepwise_htr_tool_tab:
|
|
184 |
info="Will add more models later!",
|
185 |
)
|
186 |
with gr.Row():
|
187 |
-
clear_transcribe_button = gr.Button(" ", variant="Secondary", visible=True)
|
188 |
-
full_width=True
|
189 |
-
)
|
190 |
-
transcribe_button = gr.Button("Transcribe lines", variant="primary", visible=True).style(
|
191 |
-
full_width=True
|
192 |
-
)
|
193 |
|
194 |
-
|
195 |
-
full_width=True
|
196 |
-
)
|
197 |
-
|
198 |
-
with gr.Row():
|
199 |
-
txt_file_downlod = gr.File(label="Download text", visible=False)
|
200 |
|
201 |
with gr.Column(scale=3):
|
202 |
with gr.Row():
|
203 |
transcribed_text_df = gr.Dataframe(
|
204 |
headers=["Transcribed text"],
|
205 |
-
max_rows=
|
206 |
col_count=(1, "fixed"),
|
207 |
wrap=True,
|
208 |
interactive=False,
|
209 |
overflow_row_behaviour="paginate",
|
210 |
-
|
|
|
211 |
|
212 |
#####################################
|
213 |
with gr.Tab("4. Explore Results"):
|
@@ -216,35 +205,43 @@ with gr.Blocks() as stepwise_htr_tool_tab:
|
|
216 |
# type="numpy",
|
217 |
interactive="False",
|
218 |
visible=True,
|
219 |
-
|
|
|
220 |
|
221 |
-
with gr.Row(visible=False) as control_results_transcribe:
|
222 |
with gr.Column(scale=1, visible=True):
|
223 |
with gr.Box():
|
224 |
temp_gallery_input = gr.Variable()
|
225 |
|
226 |
gallery_inputs_lines_to_transcribe = gr.Gallery(
|
227 |
label="Cropped transcribed lines",
|
228 |
-
show_label=True,
|
229 |
elem_id="gallery_lines",
|
230 |
-
).style(
|
231 |
columns=[3],
|
232 |
rows=[3],
|
233 |
# object_fit="contain",
|
234 |
-
|
235 |
preview=True,
|
236 |
container=False,
|
237 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
238 |
with gr.Column(scale=1, visible=True):
|
239 |
mapping_dict = gr.Variable()
|
240 |
transcribed_text_df_finish = gr.Dataframe(
|
241 |
-
headers=["Transcribed text", "
|
242 |
-
max_rows=
|
243 |
col_count=(2, "fixed"),
|
244 |
wrap=True,
|
245 |
interactive=False,
|
246 |
overflow_row_behaviour="paginate",
|
247 |
-
|
|
|
248 |
|
249 |
# custom track
|
250 |
region_segment_button.click(
|
@@ -260,7 +257,7 @@ with gr.Blocks() as stepwise_htr_tool_tab:
|
|
260 |
transcribed_text_df_finish.select(
|
261 |
fn=custom_track.get_select_index_df,
|
262 |
inputs=[transcribed_text_df_finish, mapping_dict],
|
263 |
-
outputs=gallery_inputs_lines_to_transcribe,
|
264 |
)
|
265 |
|
266 |
line_segment_button.click(
|
@@ -287,23 +284,12 @@ with gr.Blocks() as stepwise_htr_tool_tab:
|
|
287 |
transcribed_text_df,
|
288 |
transcribed_text_df_finish,
|
289 |
mapping_dict,
|
290 |
-
|
291 |
control_results_transcribe,
|
292 |
image_placeholder_explore_results,
|
293 |
],
|
294 |
)
|
295 |
|
296 |
-
donwload_txt_button.click(
|
297 |
-
custom_track.download_df_to_txt,
|
298 |
-
inputs=transcribed_text_df,
|
299 |
-
outputs=[txt_file_downlod, txt_file_downlod],
|
300 |
-
)
|
301 |
-
|
302 |
-
# def remove_temp_vis():
|
303 |
-
# if os.path.exists("./vis_data"):
|
304 |
-
# os.remove("././vis_data")
|
305 |
-
# return None
|
306 |
-
|
307 |
clear_button.click(
|
308 |
lambda: (
|
309 |
(shutil.rmtree("./vis_data") if os.path.exists("./vis_data") else None, None)[1],
|
|
|
25 |
label="Image to Region segment",
|
26 |
# type="numpy",
|
27 |
tool="editor",
|
28 |
+
height=350,
|
29 |
+
)
|
30 |
|
31 |
with gr.Accordion("Region segment settings:", open=False):
|
32 |
with gr.Row():
|
|
|
64 |
"Segment Region",
|
65 |
variant="primary",
|
66 |
elem_id="region_segment_button",
|
67 |
+
)
|
68 |
|
69 |
with gr.Row():
|
70 |
with gr.Accordion("Example images to use:", open=False) as example_accord:
|
|
|
76 |
)
|
77 |
|
78 |
with gr.Column(scale=3):
|
79 |
+
output_region_image = gr.Image(label="Segmented regions", type="numpy", height=600)
|
80 |
|
81 |
##############################################
|
82 |
with gr.Tab("2. Line Segmentation"):
|
|
|
85 |
# type="numpy",
|
86 |
interactive="False",
|
87 |
visible=True,
|
88 |
+
height=600,
|
89 |
+
)
|
90 |
|
91 |
with gr.Row(visible=False) as control_line_segment:
|
92 |
with gr.Column(scale=2):
|
93 |
with gr.Box():
|
94 |
regions_cropped_gallery = gr.Gallery(
|
95 |
label="Segmented regions",
|
|
|
96 |
elem_id="gallery",
|
|
|
97 |
columns=[2],
|
98 |
rows=[2],
|
99 |
# object_fit="contain",
|
100 |
+
height=450,
|
101 |
preview=True,
|
102 |
container=False,
|
103 |
)
|
104 |
|
105 |
input_region_from_gallery = gr.Image(
|
106 |
+
label="Region segmentation to line segment", interactive="False", visible=False, height=400
|
107 |
+
)
|
108 |
+
|
109 |
with gr.Row():
|
110 |
with gr.Accordion("Line segment settings:", open=False):
|
111 |
with gr.Row():
|
|
|
127 |
info="""The minimum required overlap or similarity
|
128 |
for a detected region or object to be considered valid""",
|
129 |
)
|
130 |
+
with gr.Row(equal_height=False):
|
131 |
line_segment_model_dropdown = gr.Dropdown(
|
132 |
choices=["Riksarkivet/RmtDet_lines"],
|
133 |
value="Riksarkivet/RmtDet_lines",
|
|
|
139 |
" ",
|
140 |
variant="Secondary",
|
141 |
# elem_id="center_button",
|
142 |
+
scale=1,
|
143 |
+
)
|
144 |
|
145 |
line_segment_button = gr.Button(
|
146 |
"Segment Lines",
|
147 |
variant="primary",
|
148 |
# elem_id="center_button",
|
149 |
+
scale=1,
|
150 |
+
)
|
151 |
|
152 |
with gr.Column(scale=3):
|
153 |
# gr.Markdown("""lorem ipsum""")
|
154 |
|
155 |
output_line_from_region = gr.Image(
|
156 |
+
label="Segmented lines", type="numpy", interactive="False", height=600
|
157 |
+
)
|
|
|
|
|
158 |
|
159 |
###############################################
|
160 |
with gr.Tab("3. Transcribe Text"):
|
|
|
163 |
# type="numpy",
|
164 |
interactive="False",
|
165 |
visible=True,
|
166 |
+
height=600,
|
167 |
+
)
|
168 |
|
169 |
with gr.Row(visible=False) as control_htr:
|
170 |
inputs_lines_to_transcribe = gr.Variable()
|
171 |
|
172 |
with gr.Column(scale=2):
|
173 |
image_inputs_lines_to_transcribe = gr.Image(
|
174 |
+
label="Transcribed lines", type="numpy", interactive="False", visible=False, height=470
|
175 |
+
)
|
|
|
|
|
|
|
|
|
176 |
with gr.Row():
|
177 |
with gr.Accordion("Transcribe settings:", open=False):
|
178 |
transcriber_model = gr.Dropdown(
|
|
|
182 |
info="Will add more models later!",
|
183 |
)
|
184 |
with gr.Row():
|
185 |
+
clear_transcribe_button = gr.Button(" ", variant="Secondary", visible=True, scale=1)
|
|
|
|
|
|
|
|
|
|
|
186 |
|
187 |
+
transcribe_button = gr.Button("Transcribe Lines", variant="primary", visible=True, scale=1)
|
|
|
|
|
|
|
|
|
|
|
188 |
|
189 |
with gr.Column(scale=3):
|
190 |
with gr.Row():
|
191 |
transcribed_text_df = gr.Dataframe(
|
192 |
headers=["Transcribed text"],
|
193 |
+
max_rows=14,
|
194 |
col_count=(1, "fixed"),
|
195 |
wrap=True,
|
196 |
interactive=False,
|
197 |
overflow_row_behaviour="paginate",
|
198 |
+
height=600,
|
199 |
+
)
|
200 |
|
201 |
#####################################
|
202 |
with gr.Tab("4. Explore Results"):
|
|
|
205 |
# type="numpy",
|
206 |
interactive="False",
|
207 |
visible=True,
|
208 |
+
height=600,
|
209 |
+
)
|
210 |
|
211 |
+
with gr.Row(visible=False, equal_height=False) as control_results_transcribe:
|
212 |
with gr.Column(scale=1, visible=True):
|
213 |
with gr.Box():
|
214 |
temp_gallery_input = gr.Variable()
|
215 |
|
216 |
gallery_inputs_lines_to_transcribe = gr.Gallery(
|
217 |
label="Cropped transcribed lines",
|
|
|
218 |
elem_id="gallery_lines",
|
|
|
219 |
columns=[3],
|
220 |
rows=[3],
|
221 |
# object_fit="contain",
|
222 |
+
height=300,
|
223 |
preview=True,
|
224 |
container=False,
|
225 |
)
|
226 |
+
|
227 |
+
dataframe_text_index = gr.Textbox(
|
228 |
+
label="Text from DataFrame selection",
|
229 |
+
info="Click on a dataframe cell to view the corresponding transcribed text line crop. You can also sort the dataframe to easily locate specific entries.",
|
230 |
+
lines=2,
|
231 |
+
interactive=False,
|
232 |
+
)
|
233 |
+
|
234 |
with gr.Column(scale=1, visible=True):
|
235 |
mapping_dict = gr.Variable()
|
236 |
transcribed_text_df_finish = gr.Dataframe(
|
237 |
+
headers=["Transcribed text", "pred score"],
|
238 |
+
max_rows=14,
|
239 |
col_count=(2, "fixed"),
|
240 |
wrap=True,
|
241 |
interactive=False,
|
242 |
overflow_row_behaviour="paginate",
|
243 |
+
height=600,
|
244 |
+
)
|
245 |
|
246 |
# custom track
|
247 |
region_segment_button.click(
|
|
|
257 |
transcribed_text_df_finish.select(
|
258 |
fn=custom_track.get_select_index_df,
|
259 |
inputs=[transcribed_text_df_finish, mapping_dict],
|
260 |
+
outputs=[gallery_inputs_lines_to_transcribe, dataframe_text_index],
|
261 |
)
|
262 |
|
263 |
line_segment_button.click(
|
|
|
284 |
transcribed_text_df,
|
285 |
transcribed_text_df_finish,
|
286 |
mapping_dict,
|
287 |
+
# Hide
|
288 |
control_results_transcribe,
|
289 |
image_placeholder_explore_results,
|
290 |
],
|
291 |
)
|
292 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
293 |
clear_button.click(
|
294 |
lambda: (
|
295 |
(shutil.rmtree("./vis_data") if os.path.exists("./vis_data") else None, None)[1],
|