Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,099 Bytes
c115883 2a0d582 c115883 2a0d582 198f8ab 2a0d582 198f8ab 2a0d582 198f8ab 2a0d582 c115883 2a0d582 c115883 2a0d582 c115883 2a0d582 c115883 2a0d582 c115883 2a0d582 c115883 2a0d582 c115883 2a0d582 c115883 2a0d582 c115883 2a0d582 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 |
import gradio as gr
import pandas as pd
import numpy as np
from htrflow.volume.volume import Collection
from htrflow.utils.draw import draw_polygons
from htrflow.utils import imgproc
import time
from htrflow.results import Segment
def _best_transcription_row(text_result):
    """Return one table row describing a single line's recognition result.

    If the result carries several candidate transcriptions, the one with the
    highest score is shown (the first one wins on ties). A missing or empty
    result yields a row with empty cells, so the row still lines up with the
    crop it belongs to.
    """
    candidates = list(zip(text_result.texts, text_result.scores)) if text_result else []
    if not candidates:
        return {"Transcription": "", "Confidence Score": ""}
    text, score = max(candidates, key=lambda candidate: candidate[1])
    return {"Transcription": text, "Confidence Score": f"{score:.4f}"}


def load_visualize_state_from_submit(col: Collection, progress):
    """Build per-page visualization data for a processed HTRflow collection.

    For every page node in ``col``:
    - draw the line polygons onto a copy of the page image,
    - cut out one crop per line that has a "segment",
    - tabulate that line's recognized text.

    Args:
        col: The HTRflow ``Collection`` to visualize.
        progress: Callable invoked as ``progress(fraction, desc=...)``
            (e.g. ``gr.Progress``).

    Returns:
        A list with one dict per page, with these keys:
        - "page_image": the page node itself (not an image array).
        - "annotated_page_node": the page image with line polygons drawn,
          clipped to [0, 255] and cast to uint8.
        - "line_crops": uint8 crops, one per line that has a segment.
        - "recog_conf_values": DataFrame with columns "Transcription" and
          "Confidence Score"; row i describes ``line_crops[i]``.
    """
    results = []
    # NOTE(review): artificial pause, presumably so the progress UI can
    # appear before work starts; confirm it is still needed.
    time.sleep(1)
    total_steps = len(col.pages)
    for page_idx, page_node in enumerate(col):
        page_image = page_node.image.copy()
        progress((page_idx + 1) / total_steps, desc="Running Visualizer")
        # Segment geometry is shifted by the page node's coord when it is set.
        offset = page_node.coord
        lines = list(page_node.traverse(lambda node: node.is_line()))

        line_polygons = []
        line_crops = []
        rows = []
        for ln in lines:
            seg: Segment = ln.data.get("segment")
            if not seg:
                continue
            bbox = seg.bbox.move(offset) if offset else seg.bbox
            cropped_line_img = imgproc.crop(page_image, bbox)
            line_crops.append(np.clip(cropped_line_img, 0, 255).astype(np.uint8))
            # Exactly one row per crop keeps table row i in sync with gallery
            # item i, which the gallery/table selection handlers assume.
            rows.append(_best_transcription_row(ln.text_result))
            # Test for a missing polygon before calling .move() on it.
            if seg.polygon is not None:
                line_polygons.append(
                    seg.polygon.move(offset) if offset else seg.polygon
                )

        annotated_image = draw_polygons(page_image, line_polygons)
        results.append(
            {
                "page_image": page_node,
                "annotated_page_node": np.clip(annotated_image, 0, 255).astype(np.uint8),
                "line_crops": line_crops,
                # Explicit columns keep the headers even for a page with no lines.
                "recog_conf_values": pd.DataFrame(
                    rows, columns=["Transcription", "Confidence Score"]
                ),
            }
        )
    return results
with gr.Blocks() as visualizer:
    with gr.Column(variant="panel"):
        with gr.Row():
            # Input: the processed collection. Output: per-page results from
            # load_visualize_state_from_submit.
            collection_viz_state = gr.State()
            result_collection_viz_state = gr.State()
            with gr.Column():
                viz_image_gallery = gr.Gallery(
                    file_types=["image"],
                    label="Visualized Images from HTRflow",
                    interactive=False,
                    height=400,
                    object_fit="cover",
                    columns=5,
                    preview=True,
                )
                visualize_button = gr.Button(
                    "Visualize", scale=0, min_width=200, variant="secondary"
                )
                progress_bar = gr.Textbox(visible=False, show_label=False)
            with gr.Column():
                cropped_image_gallery = gr.Gallery(
                    interactive=False,
                    preview=True,
                    label="Cropped Polygons",
                    height=200,
                )
                df_for_cropped_images = gr.Dataframe(
                    label="Cropped Transcriptions",
                    headers=["Transcription", "Confidence Score"],
                    interactive=False,
                )

    def _empty_line_view():
        """Return (no crops, empty transcription table) for the line-level widgets."""
        return [], pd.DataFrame(columns=["Transcription", "Confidence Score"])

    def on_visualize_button_clicked(collection_viz, progress=gr.Progress()):
        """Turn the submitted collection into visualization results.

        Returns three values, matching the three wired outputs:
            1) the per-page results list (or None), stored as state;
            2) the annotated page images, one per page, for viz_image_gallery;
            3) gr.skip(), leaving progress_bar unchanged.
        """
        if not collection_viz:
            # Must still return one value per output component.
            return None, [], gr.skip()
        results = load_visualize_state_from_submit(collection_viz, progress)
        annotated_images = [r["annotated_page_node"] for r in results]
        return results, annotated_images, gr.skip()

    # Show the (hidden) progress textbox while the visualizer runs, then hide it.
    visualize_button.click(lambda: gr.update(visible=True), outputs=progress_bar).then(
        fn=on_visualize_button_clicked,
        inputs=collection_viz_state,
        outputs=[result_collection_viz_state, viz_image_gallery, progress_bar],
    ).then(lambda: gr.update(visible=False), outputs=progress_bar)

    @viz_image_gallery.change(
        inputs=result_collection_viz_state,
        outputs=[cropped_image_gallery, df_for_cropped_images],
    )
    def update_c_gallery_and_dataframe(results):
        """When the page gallery changes, show the first page's crops and table."""
        if not results:
            return _empty_line_view()
        selected = results[0]
        return selected["line_crops"], selected["recog_conf_values"]

    @viz_image_gallery.select(
        inputs=result_collection_viz_state,
        outputs=[cropped_image_gallery, df_for_cropped_images],
    )
    def on_gallery_select(evt: gr.SelectData, results):
        """Show the crops and transcriptions of the page clicked in the gallery.

        evt.index is the index of the selected page image; results is the
        state stored in result_collection_viz_state.
        """
        if not results or evt.index is None:
            return _empty_line_view()
        selected = results[evt.index]
        return selected["line_crops"], selected["recog_conf_values"]

    @df_for_cropped_images.select(
        outputs=[cropped_image_gallery],
    )
    def on_dataframe_select(evt: gr.SelectData):
        """Select the crop matching the clicked table row (evt.index is [row, col])."""
        return gr.update(selected_index=evt.index[0])

    @cropped_image_gallery.select(
        inputs=df_for_cropped_images, outputs=df_for_cropped_images
    )
    def return_image_from_gallery(df, evt: gr.SelectData):
        """Highlight the table row whose index label equals the clicked crop index."""
        selected_index = evt.index

        def highlight_row(row):
            style = (
                "border: 1px solid blue; font-weight: bold"
                if row.name == selected_index
                else ""
            )
            return [style for _ in row]

        return df.style.apply(highlight_row, axis=1)
|