htr_demo / app /tabs /visualizer.py
amlpai04
forgot assets catalog
198f8ab
raw
history blame
6.1 kB
import gradio as gr
import pandas as pd
import numpy as np
from htrflow.volume.volume import Collection
from htrflow.utils.draw import draw_polygons
from htrflow.utils import imgproc
import time
from htrflow.results import Segment
def load_visualize_state_from_submit(col: Collection, progress):
results = []
time.sleep(1)
total_steps = len(col.pages)
for page_idx, page_node in enumerate(col):
page_image = page_node.image.copy()
progress((page_idx + 1) / total_steps, desc=f"Running Visualizer")
lines = list(page_node.traverse(lambda node: node.is_line()))
recog_conf_values = {
i: list(zip(tr.texts, tr.scores)) if (tr := ln.text_result) else []
for i, ln in enumerate(lines)
}
recog_df = pd.DataFrame(
[
{"Transcription": text, "Confidence Score": f"{score:.4f}"}
for values in recog_conf_values.values()
for text, score in values
]
)
line_polygons = []
line_crops = []
for ln in lines:
seg: Segment = ln.data.get("segment")
if not seg:
continue
polygon = (
seg.polygon.move(page_node.coord) if page_node.coord else seg.polygon
)
bbox = seg.bbox.move(page_node.coord) if page_node.coord else seg.bbox
cropped_line_img = imgproc.crop(page_image, bbox)
cropped_line_img = np.clip(cropped_line_img, 0, 255).astype(np.uint8)
line_crops.append(cropped_line_img)
if polygon is not None:
line_polygons.append(polygon)
annotated_image = draw_polygons(page_image, line_polygons)
annotated_page_node = np.clip(annotated_image, 0, 255).astype(np.uint8)
results.append(
{
"page_image": page_node,
"annotated_page_node": annotated_page_node,
"line_crops": line_crops,
"recog_conf_values": recog_df,
}
)
return results
with gr.Blocks() as visualizer:
with gr.Column(variant="panel"):
with gr.Row():
collection_viz_state = gr.State()
result_collection_viz_state = gr.State()
with gr.Column():
viz_image_gallery = gr.Gallery(
file_types=["image"],
label="Visualized Images from HTRflow",
interactive=False,
height=400,
object_fit="cover",
columns=5,
preview=True,
)
visualize_button = gr.Button(
"Visualize", scale=0, min_width=200, variant="secondary"
)
progress_bar = gr.Textbox(visible=False, show_label=False)
with gr.Column():
cropped_image_gallery = gr.Gallery(
interactive=False,
preview=True,
label="Cropped Polygons",
height=200,
)
df_for_cropped_images = gr.Dataframe(
label="Cropped Transcriptions",
headers=["Transcription", "Confidence Score"],
interactive=False,
)
def on_visualize_button_clicked(collection_viz, progress=gr.Progress()):
"""
This function:
- Receives the collection (collection_viz).
- Processes it into 'results' (list of dicts with annotated_page_node, line_crops, dataframe).
- Returns:
1) 'results' as state
2) List of annotated_page_node images (one per page) to populate viz_image_gallery
"""
if not collection_viz:
return None, []
results = load_visualize_state_from_submit(collection_viz, progress)
annotated_images = [r["annotated_page_node"] for r in results]
return results, annotated_images, gr.skip()
visualize_button.click(lambda: gr.update(visible=True), outputs=progress_bar).then(
fn=on_visualize_button_clicked,
inputs=collection_viz_state,
outputs=[result_collection_viz_state, viz_image_gallery, progress_bar],
).then(lambda: gr.update(visible=False), outputs=progress_bar)
@viz_image_gallery.change(
inputs=result_collection_viz_state,
outputs=[cropped_image_gallery, df_for_cropped_images],
)
def update_c_gallery_and_dataframe(results):
selected = results[0]
return selected["line_crops"], selected["recog_conf_values"]
@viz_image_gallery.select(
inputs=result_collection_viz_state,
outputs=[cropped_image_gallery, df_for_cropped_images],
)
def on_dataframe_select(evt: gr.SelectData, results):
"""
evt.index => the index of the selected image in the gallery
results => the state object from result_collection_viz_state
Return the line crops and the recognized text for that index.
"""
if results is None or evt.index is None:
return [], pd.DataFrame(columns=["Transcription", "Confidence Score"])
idx = evt.index
selected = results[idx]
return selected["line_crops"], selected["recog_conf_values"]
@df_for_cropped_images.select(
outputs=[cropped_image_gallery],
)
def on_dataframe_select(evt: gr.SelectData):
return gr.update(selected_index=evt.index[0])
@cropped_image_gallery.select(
inputs=df_for_cropped_images, outputs=df_for_cropped_images
)
def return_image_from_gallery(df, evt: gr.SelectData):
selected_index = evt.index
def highlight_row(row):
return [
(
"border: 1px solid blue; font-weight: bold"
if row.name == selected_index
else ""
)
for _ in row
]
styler = df.style.apply(highlight_row, axis=1)
return styler