Spaces:
Starting
on
T4
Starting
on
T4
import os | |
import gradio as gr | |
from helper.examples.examples import DemoImages | |
from helper.utils import TrafficDataHandler | |
from src.htr_pipeline.gradio_backend import FastTrack, SingletonModelLoader | |
model_loader = SingletonModelLoader() | |
fast_track = FastTrack(model_loader) | |
images_for_demo = DemoImages() | |
terminate = False | |
with gr.Blocks() as htr_tool_tab: | |
with gr.Row(equal_height=True): | |
with gr.Column(scale=2): | |
with gr.Row(): | |
fast_track_input_region_image = gr.Image( | |
label="Image to run HTR on", type="numpy", tool="editor", elem_id="image_upload", height=395 | |
) | |
with gr.Row(): | |
with gr.Tab("HTRFLOW") as tab_output_and_setting_selector: | |
with gr.Row(): | |
stop_htr_button = gr.Button( | |
value="Stop run", | |
variant="stop", | |
) | |
htr_pipeline_button = gr.Button( | |
"Run ", | |
variant="primary", | |
visible=True, | |
elem_id="run_pipeline_button", | |
) | |
htr_pipeline_button_var = gr.State(value="htr_pipeline_button") | |
htr_pipeline_button_api = gr.Button("Run pipeline", variant="primary", visible=False, scale=1) | |
fast_file_downlod = gr.File( | |
label="Download output file", visible=True, scale=1, height=100, elem_id="download_file" | |
) | |
with gr.Tab("Visualize") as tab_image_viewer_selector: | |
with gr.Row(): | |
gr.Markdown("") | |
run_image_visualizer_button = gr.Button( | |
value="Visualize results", variant="primary", interactive=True | |
) | |
selection_text_from_image_viewer = gr.Textbox( | |
interactive=False, label="Text Selector", info="Select a line on Image Viewer to return text" | |
) | |
with gr.Tab("Compare") as tab_model_compare_selector: | |
with gr.Box(): | |
gr.Markdown( | |
""" | |
**Work in progress** | |
Compare different runs with uploaded Ground Truth and calculate CER. You will also be able to upload output format files | |
""" | |
) | |
calc_cer_button_fast = gr.Button("Calculate CER", variant="primary", visible=True) | |
with gr.Column(scale=4): | |
with gr.Box(): | |
with gr.Row(visible=True) as output_and_setting_tab: | |
with gr.Column(scale=2): | |
fast_name_files_placeholder = gr.Markdown(visible=False) | |
gr.Examples( | |
examples=images_for_demo.examples_list, | |
inputs=[fast_name_files_placeholder, fast_track_input_region_image], | |
label="Example images", | |
examples_per_page=5, | |
) | |
gr.Markdown(" ") | |
with gr.Column(scale=3): | |
with gr.Group(): | |
gr.Markdown(" ⚙️ Settings ") | |
with gr.Row(): | |
radio_file_input = gr.CheckboxGroup( | |
choices=["Txt", "Page XML"], | |
value=["Txt", "Page XML"], | |
label="Output file extension", | |
info="JSON and ALTO-XML will be added", | |
scale=1, | |
) | |
with gr.Row(): | |
gr.Checkbox( | |
value=True, | |
label="Binarize image", | |
info="Binarize image to reduce background noise", | |
) | |
gr.Checkbox( | |
value=True, | |
label="Output prediction threshold", | |
info="Output XML with prediction score", | |
) | |
with gr.Accordion("Advanced settings", open=False): | |
with gr.Group(): | |
with gr.Row(): | |
htr_tool_region_segment_model_dropdown = gr.Dropdown( | |
choices=["Riksarkivet/rtmdet_region"], | |
value="Riksarkivet/rtmdet_region", | |
label="Region segmentation models", | |
info="More models will be added", | |
) | |
gr.Slider( | |
minimum=0.4, | |
maximum=1, | |
value=0.5, | |
step=0.05, | |
label="P-threshold", | |
info="""Filter confidence score for a prediction score to be considered""", | |
) | |
with gr.Row(): | |
htr_tool_line_segment_model_dropdown = gr.Dropdown( | |
choices=["Riksarkivet/rtmdet_lines"], | |
value="Riksarkivet/rtmdet_lines", | |
label="Line segmentation models", | |
info="More models will be added", | |
) | |
gr.Slider( | |
minimum=0.4, | |
maximum=1, | |
value=0.5, | |
step=0.05, | |
label="P-threshold", | |
info="""Filter confidence score for a prediction score to be considered""", | |
) | |
with gr.Row(): | |
htr_tool_transcriber_model_dropdown = gr.Dropdown( | |
choices=["Riksarkivet/satrn_htr", "microsoft/trocr-base-handwritten"], | |
value="Riksarkivet/satrn_htr", | |
label="Text recognition models", | |
info="More models will be added", | |
) | |
gr.Slider( | |
value=0.6, | |
minimum=0.5, | |
maximum=1, | |
label="HTR threshold", | |
info="Prediction score threshold for transcribed lines", | |
scale=1, | |
) | |
with gr.Row(): | |
gr.Markdown(" More settings will be added") | |
with gr.Row(visible=False) as image_viewer_tab: | |
text_polygon_dict = gr.Variable() | |
fast_track_output_image = gr.Image( | |
label="Image Viewer", type="numpy", height=600, interactive=False | |
) | |
with gr.Column(visible=False) as model_compare_selector: | |
gr.Markdown("**Work in progress:**") | |
with gr.Row(): | |
gr.Radio( | |
choices=["Compare Page XML", "Compare different runs"], | |
value="Compare Page XML", | |
info="Compare different runs from HTRFLOW or with external runs.", | |
) | |
with gr.Row(): | |
gr.UploadButton(label="Run A") | |
gr.UploadButton(label="Run B") | |
gr.UploadButton(label="Ground Truth") | |
with gr.Row(): | |
gr.HighlightedText( | |
label="Text diff runs", | |
combine_adjacent=True, | |
show_legend=True, | |
color_map={"+": "red", "-": "green"}, | |
) | |
with gr.Row(): | |
gr.HighlightedText( | |
label="Text diff ground truth", | |
combine_adjacent=True, | |
show_legend=True, | |
color_map={"+": "red", "-": "green"}, | |
) | |
with gr.Row(): | |
with gr.Column(scale=1): | |
with gr.Row(equal_height=False): | |
cer_output_fast = gr.Textbox(label="CER:") | |
with gr.Column(scale=2): | |
pass | |
xml_rendered_placeholder_for_api = gr.Textbox(placeholder="XML", visible=False) | |
htr_event_click_event = htr_pipeline_button.click( | |
fast_track.segment_to_xml, | |
inputs=[fast_track_input_region_image, radio_file_input], | |
outputs=[fast_file_downlod, fast_file_downlod], | |
queue=False, | |
api_name=False, | |
) | |
htr_pipeline_button_api.click( | |
fast_track.segment_to_xml_api, | |
inputs=[fast_track_input_region_image], | |
outputs=[xml_rendered_placeholder_for_api], | |
queue=False, | |
api_name="run_htr_pipeline", | |
) | |
def dummy_update_htr_tool_transcriber_model_dropdown(htr_tool_transcriber_model_dropdown): | |
return gr.update(value="Riksarkivet/satrn_htr") | |
htr_tool_transcriber_model_dropdown.change( | |
fn=dummy_update_htr_tool_transcriber_model_dropdown, | |
inputs=htr_tool_transcriber_model_dropdown, | |
outputs=htr_tool_transcriber_model_dropdown, | |
queue=False, | |
api_name=False, | |
) | |
def update_selected_tab_output_and_setting(): | |
return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False) | |
def update_selected_tab_image_viewer(): | |
return gr.update(visible=False), gr.update(visible=True), gr.update(visible=False) | |
def update_selected_tab_model_compare(): | |
return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True) | |
tab_output_and_setting_selector.select( | |
fn=update_selected_tab_output_and_setting, | |
outputs=[output_and_setting_tab, image_viewer_tab, model_compare_selector], | |
queue=False, | |
api_name=False, | |
) | |
tab_image_viewer_selector.select( | |
fn=update_selected_tab_image_viewer, | |
outputs=[output_and_setting_tab, image_viewer_tab, model_compare_selector], | |
queue=False, | |
api_name=False, | |
) | |
tab_model_compare_selector.select( | |
fn=update_selected_tab_model_compare, | |
outputs=[output_and_setting_tab, image_viewer_tab, model_compare_selector], | |
queue=False, | |
api_name=False, | |
) | |
def stop_function(): | |
from src.htr_pipeline.utils import pipeline_inferencer | |
pipeline_inferencer.terminate = True | |
gr.Info("The HTR execution was halted") | |
stop_htr_button.click( | |
fn=stop_function, | |
inputs=None, | |
outputs=None, | |
queue=False, | |
api_name=False, | |
# cancels=[htr_event_click_event], | |
) | |
run_image_visualizer_button.click( | |
fn=fast_track.visualize_image_viewer, | |
inputs=fast_track_input_region_image, | |
outputs=[fast_track_output_image, text_polygon_dict], | |
queue=False, | |
api_name=False, | |
) | |
fast_track_output_image.select( | |
fast_track.get_text_from_coords, | |
inputs=text_polygon_dict, | |
outputs=selection_text_from_image_viewer, | |
queue=False, | |
api_name=False, | |
) | |
SECRET_KEY = os.environ.get("HUB_TOKEN", False) | |
if SECRET_KEY: | |
htr_pipeline_button.click( | |
fn=TrafficDataHandler.store_metric_data, | |
inputs=htr_pipeline_button_var, | |
) | |