Spaces:
Running
on
Zero
Running
on
Zero
import gradio as gr | |
import pandas as pd | |
import numpy as np | |
from htrflow.volume.volume import Collection | |
from htrflow.utils.draw import draw_polygons | |
from htrflow.utils import imgproc | |
import time | |
from htrflow.results import Segment | |
def load_visualize_state_from_submit(col: Collection, progress):
    """Build the per-page visualisation payload for a collection.

    For every page yielded by ``col`` this gathers the line nodes, tabulates
    their transcription candidates with confidence scores, crops each line
    that carries a segment out of the page image, and draws the collected
    line polygons onto the page image.

    Args:
        col: Collection to visualise. It is iterated page by page, and
            ``len(col.pages)`` is the total used for progress reporting.
        progress: Callable invoked once per page as
            ``progress(fraction, desc="Running Visualizer")``.

    Returns:
        A list with one dict per page holding the keys ``"page_image"``,
        ``"annotated_page_node"``, ``"line_crops"`` and
        ``"recog_conf_values"``.
    """
    table_columns = ["Transcription", "Confidence Score"]
    results = []

    # NOTE(review): fixed one-second pause before any work starts; its purpose
    # is not visible here (presumably it gives the progress UI time to show).
    time.sleep(1)

    total_steps = len(col.pages)

    for page_idx, page_node in enumerate(col):
        page_image = page_node.image.copy()
        progress((page_idx + 1) / total_steps, desc="Running Visualizer")

        lines = list(page_node.traverse(lambda node: node.is_line()))

        # One row per (text, score) candidate; lines without a text result
        # contribute no rows.
        rows = [
            {"Transcription": text, "Confidence Score": f"{score:.4f}"}
            for ln in lines
            if (tr := ln.text_result)
            for text, score in zip(tr.texts, tr.scores)
        ]
        # Passing the columns explicitly keeps the table schema stable even
        # when a page has no recognised text (otherwise the frame would have
        # no columns at all).
        recog_df = pd.DataFrame(rows, columns=table_columns)

        offset = page_node.coord
        line_polygons = []
        line_crops = []
        for ln in lines:
            seg: Segment = ln.data.get("segment")
            if not seg:
                # Lines without a segment yield neither a crop nor a polygon.
                continue

            # Shift the geometry by the page offset when the page has one.
            polygon = seg.polygon.move(offset) if offset else seg.polygon
            bbox = seg.bbox.move(offset) if offset else seg.bbox

            crop = imgproc.crop(page_image, bbox)
            line_crops.append(np.clip(crop, 0, 255).astype(np.uint8))

            if polygon is not None:
                line_polygons.append(polygon)

        annotated_image = draw_polygons(page_image, line_polygons)
        annotated_page_node = np.clip(annotated_image, 0, 255).astype(np.uint8)

        results.append(
            {
                # NOTE(review): despite the key name this stores the page node
                # itself, not the copied image array — confirm with consumers.
                "page_image": page_node,
                "annotated_page_node": annotated_page_node,
                "line_crops": line_crops,
                "recog_conf_values": recog_df,
            }
        )

    return results
# Layout of the visualizer tab.
# NOTE(review): the nesting below is reconstructed from a source whose
# indentation was lost — confirm that the two columns are meant to sit side by
# side inside the row.
with gr.Blocks() as visualizer:
    with gr.Column(variant="panel"):
        with gr.Row():
            # State holders: the incoming collection and the per-page results
            # computed from it by the visualize handler.
            collection_viz_state = gr.State()
            result_collection_viz_state = gr.State()
            with gr.Column():
                # Gallery of annotated page images, one entry per page.
                viz_image_gallery = gr.Gallery(
                    file_types=["image"],
                    label="Visualized Images from HTRflow",
                    interactive=False,
                    height=400,
                    object_fit="cover",
                    columns=5,
                    preview=True,
                )
                visualize_button = gr.Button(
                    "Visualize", scale=0, min_width=200, variant="secondary"
                )
                # Hidden textbox that is shown only while the visualize
                # handler runs and is listed among that handler's outputs.
                progress_bar = gr.Textbox(visible=False, show_label=False)
            with gr.Column():
                # Line crops of a page.
                cropped_image_gallery = gr.Gallery(
                    interactive=False,
                    preview=True,
                    label="Cropped Polygons",
                    height=200,
                )
                # Transcriptions and confidence scores of a page's lines.
                df_for_cropped_images = gr.Dataframe(
                    label="Cropped Transcriptions",
                    headers=["Transcription", "Confidence Score"],
                    interactive=False,
                )
def on_visualize_button_clicked(collection_viz, progress=gr.Progress()):
    """Process the collection and populate the page gallery.

    Args:
        collection_viz: The collection held in the visualizer state; may be
            empty or ``None`` when nothing has been submitted yet.
        progress: Gradio progress tracker. It is declared as a default
            argument, which is how Gradio injects progress tracking into an
            event handler, and is forwarded to the page-processing helper.

    Returns:
        A 3-tuple matching the three outputs wired to this handler:
          1) the per-page ``results`` list to store as state (``None`` when
             there is no collection),
          2) the annotated page images, one per page, for the page gallery,
          3) ``gr.skip()`` so the progress textbox value is left untouched.
    """
    if not collection_viz:
        # The listener declares three outputs, so the early exit has to
        # return three values as well.
        return None, [], gr.skip()

    results = load_visualize_state_from_submit(collection_viz, progress)
    annotated_images = [r["annotated_page_node"] for r in results]
    return results, annotated_images, gr.skip()
# Click pipeline: reveal the progress textbox, run the visualizer, then hide
# the textbox again.
# NOTE(review): Gradio event listeners must be registered inside the
# `gr.Blocks` context, so this statement belongs in the body of
# `with gr.Blocks() as visualizer:`.
(
    visualize_button.click(
        fn=lambda: gr.update(visible=True),
        outputs=progress_bar,
    )
    .then(
        fn=on_visualize_button_clicked,
        inputs=collection_viz_state,
        outputs=[result_collection_viz_state, viz_image_gallery, progress_bar],
    )
    .then(
        fn=lambda: gr.update(visible=False),
        outputs=progress_bar,
    )
)
def update_c_gallery_and_dataframe(results):
    """Return the line crops and transcription table of the first page.

    Args:
        results: List of per-page result dicts (each with ``"line_crops"``
            and ``"recog_conf_values"``), or ``None`` / an empty list when
            nothing has been visualised.

    Returns:
        A ``(line_crops, table)`` pair for the first page. When ``results``
        is ``None`` or empty, an empty crop list and an empty DataFrame with
        the "Transcription" / "Confidence Score" columns are returned.
    """
    if not results:
        # The visualize handler stores None when there is no collection, so
        # indexing unconditionally would raise here.
        return [], pd.DataFrame(columns=["Transcription", "Confidence Score"])

    first_page = results[0]
    return first_page["line_crops"], first_page["recog_conf_values"]
def on_dataframe_select(evt: gr.SelectData, results):
    """Return the line crops and transcription table for the selected entry.

    ``evt.index`` is used directly as an index into ``results`` (the list of
    per-page dicts kept in the result state). When there are no results or no
    selected index, an empty crop list and an empty table with the
    "Transcription" / "Confidence Score" columns are returned instead.

    NOTE(review): another function with this same name is defined immediately
    after this one and rebinds the name, so this version cannot be reached by
    name afterwards. Indexing ``results`` by page suggests this was meant to
    handle selections in the page gallery — confirm and rename.
    """
    if results is not None and evt.index is not None:
        page_result = results[evt.index]
        return page_result["line_crops"], page_result["recog_conf_values"]

    return [], pd.DataFrame(columns=["Transcription", "Confidence Score"])
def on_dataframe_select(evt: gr.SelectData):
    """Build an update that selects the entry matching ``evt.index[0]``.

    The first element of the event's index is passed on as ``selected_index``
    of the returned component update.
    NOTE(review): presumably this is the row of the clicked dataframe cell
    and the target is the line-crop gallery — confirm against the wiring.
    """
    row_position = evt.index[0]
    return gr.update(selected_index=row_position)
def return_image_from_gallery(df, evt: gr.SelectData):
    """Return ``df`` styled so that the row matching the selection stands out.

    Every cell of the row whose index label equals ``evt.index`` gets a blue
    border and bold text; all other cells receive an empty style string.

    Args:
        df: DataFrame to style.
        evt: Selection event; its ``index`` is compared to each row's label.

    Returns:
        The pandas Styler produced by applying the highlight row-wise.
    """
    target_label = evt.index
    highlight_css = "border: 1px solid blue; font-weight: bold"

    def highlight_row(row):
        # A single style string is repeated for every cell of the row.
        cell_style = highlight_css if row.name == target_label else ""
        return [cell_style] * len(row)

    return df.style.apply(highlight_row, axis=1)