machine-learning-ui / views /data_annotation.py
katanaml's picture
Initial commit
6a4662b
raw
history blame
15.6 kB
import streamlit as st
from PIL import Image
import streamlit_nested_layout
from streamlit_sparrow_labeling import st_sparrow_labeling
from streamlit_sparrow_labeling import DataProcessor
import json
import math
import os
class DataAnnotation:
    """Streamlit page for drawing boxes on document images and labelling them.

    Images are read from ``docs/image/`` and their annotations from ``docs/json/``;
    an image and its annotation JSON share the same file-name stem
    (``docs/image/<stem>.<ext>`` <-> ``docs/json/<stem>.json``).
    """

    class Model:
        """Static UI texts and per-run file paths for the data annotation page."""
        pageTitle = "Data Annotation"

        # Paths of the currently selected image and its annotation JSON; set by view().
        img_file = None
        rects_file = None

        assign_labels_text = "Assign Labels"
        text_caption_1 = "Check 'Assign Labels' to enable editing of labels and values, move and resize the boxes to annotate the document."
        text_caption_2 = "Add annotations by clicking and dragging on the document, when 'Assign Labels' is unchecked."

        # Label choices offered for every box; "" means "no label".
        labels = ["", "item", "item_price", "subtotal", "tax", "total", "date_issued", "due_date", "invoice_number",
                  "amount_due", "deposit_due"]

        selected_field = "Selected Field: "
        save_text = "Save"
        saved_text = "Saved!"

        subheader_1 = "Select"
        subheader_2 = "Upload"
        annotation_text = "Annotation"
        no_annotation_file = "No annotation file selected"
        no_annotation_mapping = "Please annotate the document. Uncheck 'Assign Labels' and draw new annotations"

        download_text = "Download"
        download_hint = "Download the annotated structure in JSON format"

        annotation_selection_help = "Select an annotation file to load"
        upload_help = "Upload a file to annotate"
        upload_button_text = "Upload"
        upload_button_text_desc = "Choose a file"

        assign_labels_help = "Check to enable editing of labels and values"
        save_help = "Save the annotations"

        error_text = "Value is too long. Please shorten it."

    def view(self, model, ui_width, device_type, device_width):
        """Render the page: sidebar selection/upload, the document canvas and the field form.

        :param model: a ``DataAnnotation.Model`` holding the UI texts; its ``img_file`` and
            ``rects_file`` are set to the paths of the current selection.
        :param ui_width: width available to the page, used to pick the layout.
        :param device_type: device kind; "desktop" and "mobile" are handled specially.
        :param device_width: full device width, used for the single-column fallback layout.
        """
        with st.sidebar:
            st.markdown("---")
            st.subheader(model.subheader_1)

            placeholder_upload = st.empty()
            file_names = self.get_existing_file_names('docs/image/')
            annotation_index = st.session_state.get('annotation_index', 0)
            annotation_selection = self._select_annotation(model, placeholder_upload, file_names,
                                                           annotation_index)

            st.subheader(model.subheader_2)

            with st.form("upload-form", clear_on_submit=True):
                uploaded_file = st.file_uploader(model.upload_button_text_desc, accept_multiple_files=False,
                                                 type=['png', 'jpg', 'jpeg'],
                                                 help=model.upload_help)
                # NOTE(review): the submit button is rendered disabled, so this branch cannot
                # currently be reached from the UI.
                submitted = st.form_submit_button(model.upload_button_text, disabled=True)

                if submitted and uploaded_file is not None:
                    if self.upload_file(uploaded_file) is not False:
                        # Re-render the selectbox so the new file is listed, keeping the selection.
                        file_names = self.get_existing_file_names('docs/image/')
                        annotation_index = (file_names.index(annotation_selection)
                                            if annotation_selection in file_names else 0)
                        annotation_selection = self._select_annotation(model, placeholder_upload, file_names,
                                                                       annotation_index)

        if annotation_selection is None or model.img_file is None:
            # Nothing to annotate yet (docs/image/ is empty).
            st.title(model.pageTitle)
            st.caption(model.no_annotation_file)
            return

        st.title(model.pageTitle + " - " + annotation_selection)

        try:
            saved_state = self.fetch_annotations(model.rects_file)
        except FileNotFoundError:
            # The image exists but there is no annotation JSON with the same stem.
            st.caption(model.no_annotation_file)
            return

        assign_labels = st.checkbox(model.assign_labels_text, True, help=model.assign_labels_help)
        # Checked: existing boxes can be moved/resized and edited; unchecked: draw new boxes.
        mode = "transform" if assign_labels else "rect"

        docImg = Image.open(model.img_file)
        data_processor = DataProcessor()

        with st.container():
            doc_height = saved_state['meta']['image_size']['height']
            doc_width = saved_state['meta']['image_size']['width']
            canvas_width, number_of_columns = self.canvas_available_width(ui_width, doc_width, device_type,
                                                                          device_width)

            if number_of_columns > 1:
                # number_of_columns is the canvas column's weight out of 10.
                col1, col2 = st.columns([number_of_columns, 10 - number_of_columns])
                with col1:
                    result_rects = self.render_doc(model, docImg, saved_state, mode, canvas_width, doc_height,
                                                   doc_width)
                with col2:
                    self.render_form(model, result_rects, data_processor, number_of_columns, annotation_selection)
            else:
                result_rects = self.render_doc(model, docImg, saved_state, mode, canvas_width, doc_height, doc_width)
                self.render_form(model, result_rects, data_processor, number_of_columns, annotation_selection)

    def _select_annotation(self, model, placeholder, file_names, annotation_index):
        """Render the file selectbox into *placeholder* and point *model* at the chosen files.

        Stores the selected index in ``st.session_state['annotation_index']`` and returns the
        selected stem, or None when *file_names* is empty (model paths are then set to None).
        """
        # A remembered index goes stale when files are removed from docs/image/.
        if not 0 <= annotation_index < len(file_names):
            annotation_index = 0

        annotation_selection = placeholder.selectbox(model.annotation_text, file_names,
                                                     index=annotation_index,
                                                     help=model.annotation_selection_help)
        if annotation_selection is None:
            model.img_file = None
            model.rects_file = None
            st.session_state['annotation_index'] = 0
            return None

        st.session_state['annotation_index'] = self.get_annotation_index(annotation_selection, file_names)
        file_extension = self.get_file_extension(annotation_selection, 'docs/image/')
        model.img_file = f"docs/image/{annotation_selection}" + file_extension
        model.rects_file = f"docs/json/{annotation_selection}.json"
        return annotation_selection

    def render_doc(self, model, docImg, saved_state, mode, canvas_width, doc_height, doc_width):
        """Render the labelling canvas for *docImg* and return the component's result.

        :param saved_state: annotation dict (``meta`` + ``words``) used as the initial boxes.
        :param mode: drawing mode passed to the component ("transform" or "rect").
        :returns: whatever ``st_sparrow_labeling`` returns; render_form reads its
            ``rects_data`` and ``current_rect_index`` attributes and treats None as "no result yet".
        """
        with st.container():
            # NOTE(review): fixed sizes handed to the component; the displayed width presumably
            # comes from canvas_width/image_rescale - confirm against the component docs.
            height = 1296
            width = 864

            result_rects = st_sparrow_labeling(
                fill_color="rgba(0, 151, 255, 0.3)",
                stroke_width=2,
                stroke_color="rgba(0, 50, 255, 0.7)",
                background_image=docImg,
                initial_rects=saved_state,
                height=height,
                width=width,
                drawing_mode=mode,
                display_toolbar=True,
                update_streamlit=True,
                canvas_width=canvas_width,
                doc_height=doc_height,
                doc_width=doc_width,
                image_rescale=True,
                # Per-image key so switching documents creates a fresh canvas.
                key="doc_annotation" + model.img_file
            )

            st.caption(model.text_caption_1)
            st.caption(model.text_caption_2)

            return result_rects

    def render_form(self, model, result_rects, data_processor, number_of_columns, annotation_selection):
        """Render the JSON download button and the editable value/label form for all boxes.

        Does nothing while *result_rects* is None. On save (the button is currently rendered
        disabled) the edited annotations are validated, written to ``model.rects_file``,
        cached in ``st.session_state`` under that path, and the script is rerun.
        """
        with st.container():
            if result_rects is None:
                return

            words = result_rects.rects_data['words']
            if len(words) == 0:
                st.caption(model.no_annotation_mapping)
                return

            with open(model.rects_file, 'rb') as file:
                st.download_button(label=model.download_text,
                                   data=file,
                                   file_name=annotation_selection + ".json",
                                   mime='application/json',
                                   help=model.download_hint)

            with st.form(key="fields_form"):
                current_index = result_rects.current_rect_index
                if current_index is not None and current_index != -1:
                    st.write(model.selected_field, words[current_index]['value'])
                    st.markdown("---")

                # The number of form columns shrinks as the canvas column gets wider.
                layouts = {
                    4: self.render_form_wide,
                    5: self.render_form_avg,
                    6: self.render_form_narrow,
                }
                render_fields = layouts.get(number_of_columns, self.render_form_mobile)
                render_fields(words, model.labels, result_rects, data_processor)

                submit = st.form_submit_button(model.save_text, type="primary", help=model.save_help, disabled=True)
                if submit:
                    # Re-read: the form elements above write their edits into rects_data.
                    if any(len(word['value']) > 100 for word in result_rects.rects_data['words']):
                        st.error(model.error_text)
                        return

                    with open(model.rects_file, "w", encoding="utf-8") as f:
                        json.dump(result_rects.rects_data, f, indent=2)
                    # Keep the cache used by fetch_annotations in sync with the file.
                    st.session_state[model.rects_file] = result_rects.rects_data
                    self._rerun()

    @staticmethod
    def _rerun():
        """Rerun the Streamlit script on both older and newer Streamlit versions."""
        # st.experimental_rerun is deprecated in favour of st.rerun.
        rerun = getattr(st, "rerun", None) or st.experimental_rerun
        rerun()

    def _render_form_columns(self, words, labels, result_rects, data_processor, n_columns):
        """Lay the form elements out top-to-bottom over *n_columns* equal-width columns."""
        columns = st.columns([1] * n_columns)
        num_rows = math.ceil(len(words) / n_columns)

        for i, rect in enumerate(words):
            # The first num_rows words go to column 0, the next num_rows to column 1, ...
            with columns[min(i // num_rows, n_columns - 1)]:
                self.render_form_element(rect, labels, i, result_rects, data_processor)

    def render_form_wide(self, words, labels, result_rects, data_processor):
        """Render the form elements in four columns."""
        self._render_form_columns(words, labels, result_rects, data_processor, 4)

    def render_form_avg(self, words, labels, result_rects, data_processor):
        """Render the form elements in three columns."""
        self._render_form_columns(words, labels, result_rects, data_processor, 3)

    def render_form_narrow(self, words, labels, result_rects, data_processor):
        """Render the form elements in two columns."""
        self._render_form_columns(words, labels, result_rects, data_processor, 2)

    def render_form_mobile(self, words, labels, result_rects, data_processor):
        """Render the form elements in a single column."""
        for i, rect in enumerate(words):
            self.render_form_element(rect, labels, i, result_rects, data_processor)

    def render_form_element(self, rect, labels, i, result_rects, data_processor):
        """Render the value/label inputs for word *i* and write them back into the rect data.

        Only the box currently selected on the canvas is editable. A label missing from
        *labels* (e.g. coming from a hand-edited JSON) is kept as an extra option instead of
        raising ``ValueError``.
        """
        label_options = labels
        default_index = 0
        if rect['label']:
            if rect['label'] not in labels:
                label_options = labels + [rect['label']]
            default_index = label_options.index(rect['label'])

        is_disabled = i != result_rects.current_rect_index
        value = st.text_input("Value", rect['value'], key=f"field_value_{i}", disabled=is_disabled)
        label = st.selectbox("Label", label_options, key=f"label_{i}", index=default_index, disabled=is_disabled)
        st.markdown("---")

        data_processor.update_rect_data(result_rects.rects_data, i, value, label)

    def canvas_available_width(self, ui_width, doc_width, device_type, device_width):
        """Choose the canvas width and the canvas column weight.

        On a desktop with ``ui_width`` above 700 the canvas gets 37 %, 49 % or 65 % of
        ``ui_width`` depending on how wide the document is relative to it, paired with a
        column weight of 4, 5 or 6 (out of 10, see view()). Otherwise a single-column
        layout is used and the width is derived from ``device_width`` for "desktop"/"mobile",
        or ``ui_width`` is returned unchanged for any other device type.

        :returns: ``(canvas_width, number_of_columns)``.
        """
        if ui_width > 700 and device_type == "desktop":
            # Computed only here, so a zero/unknown ui_width cannot cause ZeroDivisionError.
            doc_width_pct = (doc_width * 100) / ui_width
            if doc_width_pct < 45:
                canvas_width_pct, number_of_columns = 37, 4
            elif doc_width_pct < 55:
                canvas_width_pct, number_of_columns = 49, 5
            else:
                canvas_width_pct, number_of_columns = 65, 6
            return math.floor(canvas_width_pct * ui_width / 100), number_of_columns

        if device_type == "desktop":
            ui_width = device_width - math.floor((device_width * 22) / 100)
        elif device_type == "mobile":
            ui_width = device_width - math.floor((device_width * 13) / 100)
        return ui_width, 1

    def fetch_annotations(self, rects_file):
        """Return the annotation dict for *rects_file*, cached in ``st.session_state``.

        The file is read only on first access; later calls return the cached dict.

        :raises FileNotFoundError: if the file is not cached and does not exist.
        """
        if rects_file not in st.session_state:
            with open(rects_file, "r", encoding="utf-8") as f:
                saved_state = json.load(f)
            st.session_state[rects_file] = saved_state
        else:
            saved_state = st.session_state[rects_file]

        return saved_state

    def upload_file(self, uploaded_file):
        """Save an uploaded image to docs/image/ and create an empty annotation JSON for it.

        :returns: True on success, False when the upload is rejected (name already used,
            invalid or too long name, unreadable image), None when *uploaded_file* is None.
        """
        if uploaded_file is None:
            return None

        # The name comes from the client: drop any directory components.
        file_name = os.path.basename(uploaded_file.name)
        if not file_name or file_name.startswith('.'):
            st.write("Invalid file name")
            return False

        stem = os.path.splitext(file_name)[0]
        image_path = os.path.join("docs/image/", file_name)

        # Images and JSON files are matched by stem, so "a.png" and "a.jpg" would collide.
        if os.path.exists(image_path) or stem in self.get_existing_file_names("docs/image/"):
            st.write("File already exists")
            return False

        if len(file_name) > 100:
            st.write("File name too long")
            return False

        with open(image_path, "wb") as f:
            f.write(uploaded_file.getbuffer())

        try:
            with Image.open(image_path) as img_file:
                image_size = {
                    "width": img_file.width,
                    "height": img_file.height
                }
        except OSError:
            # Not a readable image: don't leave an orphan file without annotations behind.
            os.remove(image_path)
            st.write("Invalid image file")
            return False

        annotations_json = {
            "meta": {
                "version": "v0.1",
                "split": "train",
                # Counted after the image was written, so it includes the new file.
                "image_id": len(self.get_existing_file_names("docs/image/")),
                "image_size": image_size
            },
            "words": []
        }

        # Same stem rule as get_existing_file_names (os.path.splitext), so the page finds it.
        with open(os.path.join("docs/json/", stem + ".json"), "w", encoding="utf-8") as f:
            json.dump(annotations_json, f, indent=2)

        st.write("File uploaded successfully")
        return True

    def get_existing_file_names(self, dir_name):
        """Return the sorted file-name stems in *dir_name*, skipping hidden (dot) files."""
        return sorted([os.path.splitext(f)[0] for f in os.listdir(dir_name) if not f.startswith('.')])

    def get_file_extension(self, file_name, dir_name):
        """Return the extension (e.g. ".png") of the first non-hidden file in *dir_name* whose
        stem equals *file_name*, or None if there is none or *file_name* is None."""
        if file_name is None:
            return None
        for f in os.listdir(dir_name):
            if f.startswith('.'):
                continue
            stem, extension = os.path.splitext(f)
            if stem == file_name:
                return extension
        return None

    def get_annotation_index(self, file, files_list):
        """Return the position of *file* in *files_list* (raises ValueError if absent)."""
        return files_list.index(file)