import streamlit as st
import streamlit.components.v1 as components
from PIL import Image, ImageEnhance
import torch
from torchvision.transforms import functional as F
import gc
import psutil
import copy
import xml.etree.ElementTree as ET
import numpy as np
from xml.dom import minidom
from pathlib import Path
import gdown

from modules.htlm_webpage import display_bpmn_xml
from modules.OCR import text_prediction, filter_text, mapping_text, rescale
from modules.utils import class_dict, arrow_dict, object_dict, find_closest_object
from modules.toXML import calculate_pool_bounds, add_diagram_elements, create_bpmn_object, create_flow_element
from modules.display import draw_stream
from modules.eval import full_prediction
from modules.train import get_faster_rcnn_model, get_arrow_model
from streamlit_image_comparison import image_comparison
from streamlit_cropper import st_cropper
from streamlit_drawable_canvas import st_canvas
from streamlit_image_select import image_select
def get_memory_usage():
    process = psutil.Process()
    mem_info = process.memory_info()
    return mem_info.rss / (1024 ** 2)  # Return memory usage in MB
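
# Example use (a sketch, not wired into the app): the figure could be surfaced
# in the UI instead of only in the server logs, e.g.
#   st.sidebar.caption(f"Memory usage: {get_memory_usage():.1f} MB")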

# Clear all cached session data and force garbage collection
def clear_memory():
    st.session_state.clear()
    gc.collect()
# Function to read XML content from a file
def read_xml_file(filepath):
    """Read XML content from a file."""
    with open(filepath, 'r', encoding='utf-8') as file:
        return file.read()
# Function to modify bounding box positions based on the given sizes
def modif_box_pos(pred, size):
    modified_pred = copy.deepcopy(pred)  # Make a deep copy of the prediction
    for i, (x1, y1, x2, y2) in enumerate(modified_pred['boxes']):
        center = [(x1 + x2) / 2, (y1 + y2) / 2]
        label = class_dict[modified_pred['labels'][i]]
        if label in size:
            modified_pred['boxes'][i] = [center[0] - size[label][0] / 2,
                                         center[1] - size[label][1] / 2,
                                         center[0] + size[label][0] / 2,
                                         center[1] + size[label][1] / 2]
    return modified_pred['boxes']
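
# Worked example for modif_box_pos (illustrative numbers, not from a real
# prediction): a detected 'task' box (10, 10, 50, 50) has center (30, 30);
# with the canonical task size (120, 96) defined below, the box becomes
# (30 - 60, 30 - 48, 30 + 60, 30 + 48) == (-30, -18, 90, 78), i.e. the center
# is preserved while the width and height are normalized.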
# Function to create a BPMN XML file from prediction results
def create_XML(full_pred, text_mapping, scale):
    namespaces = {
        'bpmn': 'http://www.omg.org/spec/BPMN/20100524/MODEL',
        'bpmndi': 'http://www.omg.org/spec/BPMN/20100524/DI',
        'di': 'http://www.omg.org/spec/DD/20100524/DI',
        'dc': 'http://www.omg.org/spec/DD/20100524/DC',
        'xsi': 'http://www.w3.org/2001/XMLSchema-instance'
    }

    # Canonical (width, height) in pixels for each BPMN element type
    size_elements = {
        'start': (43.2, 43.2),
        'task': (120, 96),
        'message': (43.2, 43.2),
        'messageEvent': (43.2, 43.2),
        'end': (43.2, 43.2),
        'exclusiveGateway': (60, 60),
        'event': (43.2, 43.2),
        'parallelGateway': (60, 60),
        'dataObject': (48, 72),
        'dataStore': (72, 72),
        'subProcess': (144, 108),
        'eventBasedGateway': (60, 60),
        'timerEvent': (48, 48),
    }
    definitions = ET.Element('bpmn:definitions', {
        'xmlns:xsi': namespaces['xsi'],
        'xmlns:bpmn': namespaces['bpmn'],
        'xmlns:bpmndi': namespaces['bpmndi'],
        'xmlns:di': namespaces['di'],
        'xmlns:dc': namespaces['dc'],
        'targetNamespace': "http://example.bpmn.com",
        'id': "simpleExample"
    })
    # Snap the boxes to the canonical element sizes, keeping a copy of the
    # originals so they can be restored before returning
    old_boxes = copy.deepcopy(full_pred['boxes'])
    full_pred['boxes'] = modif_box_pos(full_pred, size_elements)
    # Create BPMN collaboration element
    collaboration = ET.SubElement(definitions, 'bpmn:collaboration', id='collaboration_1')

    # Create one BPMN process element per pool
    process = []
    for idx in range(len(full_pred['pool_dict'])):
        process_id = f'process_{idx+1}'
        process.append(ET.SubElement(definitions, 'bpmn:process', id=process_id, isExecutable='false',
                                     name=text_mapping[full_pred['BPMN_id'][list(full_pred['pool_dict'].keys())[idx]]]))

    bpmndi = ET.SubElement(definitions, 'bpmndi:BPMNDiagram', id='BPMNDiagram_1')
    bpmnplane = ET.SubElement(bpmndi, 'bpmndi:BPMNPlane', id='BPMNPlane_1', bpmnElement='collaboration_1')
    full_pred['boxes'] = rescale(scale, full_pred['boxes'])

    # Add diagram elements for each pool
    for idx, (pool_index, keep_elements) in enumerate(full_pred['pool_dict'].items()):
        pool_id = f'participant_{idx+1}'
        pool = ET.SubElement(collaboration, 'bpmn:participant', id=pool_id, processRef=f'process_{idx+1}',
                             name=text_mapping[full_pred['BPMN_id'][list(full_pred['pool_dict'].keys())[idx]]])

        # Calculate the bounding box for the pool
        if len(keep_elements) == 0:
            min_x, min_y, max_x, max_y = full_pred['boxes'][pool_index]
            pool_width = max_x - min_x
            pool_height = max_y - min_y
        else:
            min_x, min_y, max_x, max_y = calculate_pool_bounds(full_pred, keep_elements, size_elements)
            pool_width = max_x - min_x + 100   # Add padding
            pool_height = max_y - min_y + 100  # Add padding

        add_diagram_elements(bpmnplane, pool_id, min_x - 50, min_y - 50, pool_width, pool_height)

    # Create BPMN elements for each pool
    for idx, (pool_index, keep_elements) in enumerate(full_pred['pool_dict'].items()):
        create_bpmn_object(process[idx], bpmnplane, text_mapping, definitions, size_elements, full_pred, keep_elements)

    # Create message flow elements
    message_flows = [i for i, label in enumerate(full_pred['labels']) if class_dict[label] == 'messageFlow']
    for idx in message_flows:
        create_flow_element(bpmnplane, text_mapping, idx, size_elements, full_pred, collaboration, message=True)

    # Create sequence flow elements
    for idx, (pool_index, keep_elements) in enumerate(full_pred['pool_dict'].items()):
        for i in keep_elements:
            if full_pred['labels'][i] == list(class_dict.values()).index('sequenceFlow'):
                create_flow_element(bpmnplane, text_mapping, i, size_elements, full_pred, process[idx], message=False)
    # Generate a pretty-printed XML string
    rough_string = ET.tostring(definitions, 'utf-8')
    reparsed = minidom.parseString(rough_string)
    pretty_xml_as_string = reparsed.toprettyxml(indent="  ")

    # Restore the original boxes on the prediction dict before returning
    full_pred['boxes'] = old_boxes

    return pretty_xml_as_string
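
# Minimal input sketch for create_XML (hedged: key names are read off the code
# above, and the exact array types depend on modules.eval.full_prediction):
#   full_pred = {
#       'boxes':     [[x1, y1, x2, y2], ...],              # one box per detection
#       'labels':    [3, 7, ...],                          # indices into class_dict
#       'BPMN_id':   ['task_1', 'pool_1', ...],            # keys into text_mapping
#       'pool_dict': {pool_box_index: [element_indices]},  # elements grouped by pool
#   }
#   xml_string = create_XML(full_pred, text_mapping, scale=1.0)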
# Function to load the models only once, keeping them in session state
def load_models():
    with st.spinner('Loading model...'):
        model_object = get_faster_rcnn_model(len(object_dict))
        model_arrow = get_arrow_model(len(arrow_dict), 2)

        url_arrow = 'https://drive.google.com/uc?id=1xwfvo7BgDWz-1jAiJC1DCF0Wp8YlFNWt'
        url_object = 'https://drive.google.com/uc?id=1GiM8xOXG6M6r8J9HTOeMJz9NKu7iumZi'
        # Define paths to save the model checkpoints
        output_arrow = 'model_arrow.pth'
        output_object = 'model_object.pth'

        # Download the checkpoints with gdown unless they already exist locally
        if not Path(output_arrow).exists():
            gdown.download(url_arrow, output_arrow, quiet=False)
        else:
            print('Arrow model checkpoint found locally, skipping download')
        if not Path(output_object).exists():
            gdown.download(url_object, output_object, quiet=False)
        else:
            print('Object model checkpoint found locally, skipping download')
        # Load the model weights
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model_arrow.load_state_dict(torch.load(output_arrow, map_location=device))
        model_object.load_state_dict(torch.load(output_object, map_location=device))

        st.session_state.model_loaded = True
        st.session_state.model_arrow = model_arrow
        st.session_state.model_object = model_object
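
# A possible alternative (a sketch, not used by the app): st.cache_resource
# keeps one copy of the models across reruns and sessions, avoiding the manual
# session-state bookkeeping above. It assumes the same model helpers and
# checkpoint files as load_models().
@st.cache_resource
def load_models_cached():
    model_object = get_faster_rcnn_model(len(object_dict))
    model_arrow = get_arrow_model(len(arrow_dict), 2)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model_arrow.load_state_dict(torch.load('model_arrow.pth', map_location=device))
    model_object.load_state_dict(torch.load('model_object.pth', map_location=device))
    return model_object, model_arrow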
# Function to prepare the image for processing
def prepare_image(image, pad=True, new_size=(1333, 1333)):
    original_size = image.size

    # Calculate the scale that fits the new size while maintaining the aspect ratio
    scale = min(new_size[0] / original_size[0], new_size[1] / original_size[1])
    new_scaled_size = (int(original_size[0] * scale), int(original_size[1] * scale))

    # Resize to the scaled size; torchvision expects (height, width)
    image = F.resize(image, (new_scaled_size[1], new_scaled_size[0]))

    if pad:
        enhancer = ImageEnhance.Brightness(image)
        image = enhancer.enhance(1.5)  # Adjust the brightness if necessary
        # Pad the right/bottom so the image is exactly the desired size; with
        # padding_mode='edge' the border pixels are replicated and the fill
        # value is ignored
        padding = [0, 0, new_size[0] - new_scaled_size[0], new_size[1] - new_scaled_size[1]]
        image = F.pad(image, padding, fill=200, padding_mode='edge')

    return new_scaled_size, image
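
# Worked example for prepare_image (illustrative): a 2000x1000 image targeted
# at (1333, 1333) gives scale = min(1333/2000, 1333/1000) ~ 0.6665, so the
# scaled size is (1333, 666); with pad=True the image is then padded by 0 px
# on the right and 1333 - 666 = 667 px at the bottom to reach 1333x1333.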
# Function to display various options for image annotation
def display_options(image, score_threshold):
    col1, col2, col3, col4, col5 = st.columns(5)
    with col1:
        write_class = st.toggle("Write Class", value=True)
        draw_keypoints = st.toggle("Draw Keypoints", value=True)
        draw_boxes = st.toggle("Draw Boxes", value=True)
    with col2:
        draw_text = st.toggle("Draw Text", value=False)
        write_text = st.toggle("Write Text", value=False)
        draw_links = st.toggle("Draw Links", value=False)
    with col3:
        write_score = st.toggle("Write Score", value=True)
        write_idx = st.toggle("Write Index", value=False)
    with col4:
        # Define options for the class filter dropdown
        dropdown_options = list(class_dict.values())
        dropdown_options[0] = 'all'
        selected_option = st.selectbox("Show class", dropdown_options)

    # Draw the annotated image with the selected options
    annotated_image = draw_stream(
        np.array(image), prediction=st.session_state.prediction, text_predictions=st.session_state.text_pred,
        draw_keypoints=draw_keypoints, draw_boxes=draw_boxes, draw_links=draw_links, draw_twins=False, draw_grouped_text=draw_text,
        write_class=write_class, write_text=write_text, keypoints_correction=True, write_idx=write_idx, only_print=selected_option,
        score_threshold=score_threshold, write_score=write_score, resize=True, return_image=True, axis=True
    )

    # Display the original and annotated images side by side
    image_comparison(
        img1=annotated_image,
        img2=image,
        label1="Annotated Image",
        label2="Original Image",
        starting_position=99,
        width=1000,
    )
# Function to perform inference on the uploaded image using the loaded models
def perform_inference(model_object, model_arrow, image, score_threshold):
    _, uploaded_image = prepare_image(image, pad=False)
    img_tensor = F.to_tensor(prepare_image(image.convert('RGB'))[1])

    # Show the original image in a placeholder while the prediction runs
    image_placeholder = st.empty()
    image_placeholder.image(uploaded_image, caption='Original Image', width=1000)

    # Prediction
    _, st.session_state.prediction = full_prediction(model_object, model_arrow, img_tensor, score_threshold=score_threshold, iou_threshold=0.5, distance_treshold=30)

    # Perform OCR on the uploaded image
    ocr_results = text_prediction(uploaded_image)

    # Filter and map OCR results to prediction results
    st.session_state.text_pred = filter_text(ocr_results, threshold=0.5)
    st.session_state.text_mapping = mapping_text(st.session_state.prediction, st.session_state.text_pred, print_sentences=False, percentage_thresh=0.5)

    # Remove the original image display
    image_placeholder.empty()

    # Force garbage collection
    gc.collect()
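
# Note (an assumption -- full_prediction lives in modules.eval and may already
# disable autograd itself): if it does not, wrapping the call in
# torch.no_grad() would save memory during inference, e.g.
#   with torch.no_grad():
#       _, pred = full_prediction(model_object, model_arrow, img_tensor,
#                                 score_threshold=score_threshold,
#                                 iou_threshold=0.5, distance_treshold=30)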
def get_image(uploaded_file):
    return Image.open(uploaded_file).convert('RGB')
def main():
    st.set_page_config(layout="wide")

    # Add the company logo banner
    st.image("./images/banner.png", use_column_width=True)

    # Sidebar content
    st.sidebar.header("This BPMN AI model recognition is proposed by: \n ELCA in collaboration with EPFL.")
    st.sidebar.subheader("Instructions:")
    st.sidebar.text("1. Upload your image")
    st.sidebar.text("2. Crop the image \n (try to put the BPMN diagram \n in the center of the image)")
    st.sidebar.text("3. Set the score threshold \n for prediction (default is 0.5)")
    st.sidebar.text("4. Set the scale for the XML file \n (default is 1.0)")
    st.sidebar.text("5. Click on 'Launch Prediction'")
    st.sidebar.text("6. You can now see the annotation \n and the BPMN XML result")
    st.sidebar.text("7. You can modify and download \n the result in the right format")

    st.sidebar.subheader("If there is an error, try to:")
    st.sidebar.text("1. Change the score threshold")
    st.sidebar.text("2. Re-crop the image by placing\n the BPMN diagram in the center\n of the image")
    st.sidebar.text("3. Re-launch the prediction")
    st.sidebar.subheader("You can close this sidebar")
    # Set the title of the app
    st.title("BPMN recognition by AI demo")

    # Display current memory usage
    memory_usage = get_memory_usage()
    print(f"Current memory usage: {memory_usage:.2f} MB")

    # Initialize the session state for storing pool bounding boxes
    if 'pool_bboxes' not in st.session_state:
        st.session_state.pool_bboxes = []

    # Load the models once and keep them in session state
    if 'model_object' not in st.session_state or 'model_arrow' not in st.session_state:
        clear_memory()
        load_models()

    model_arrow = st.session_state.model_arrow
    model_object = st.session_state.model_object
    with st.expander("Use example images"):
        img_selected = image_select("If you have no image and just want to test the demo, click on one of these images",
                                    ["./images/none.jpg", "./images/example1.jpg", "./images/example2.jpg", "./images/example3.jpg", "./images/example4.jpg"],
                                    captions=["None", "Example 1", "Example 2", "Example 3", "Example 4"], index=0, use_container_width=False, return_value="original")

    if img_selected == './images/none.jpg':
        print('No example image selected')
        img_selected = None
    # Create the layout for the app
    col1, col2 = st.columns(2)
    with col1:
        # Use the selected example image, otherwise show a file uploader
        if img_selected is not None:
            uploaded_file = img_selected
        else:
            uploaded_file = st.file_uploader("Choose an image from my computer...", type=["jpg", "jpeg", "png"])

    # Display the uploaded image if the user has uploaded one
    if uploaded_file is not None:
        with st.spinner('Waiting for image display...'):
            original_image = get_image(uploaded_file)
            col1, col2 = st.columns(2)

            # Let the user crop the image and display the cropped result
            with col1:
                cropped_image = st_cropper(original_image, realtime_update=True, box_color='#0000FF', should_resize_image=True, default_coords=(30, original_image.size[0]-30, 30, original_image.size[1]-30))
            with col2:
                st.image(cropped_image, caption="Cropped Image", use_column_width=False, width=500)

        # Let the user set the score threshold and the XML scale
        if cropped_image is not None:
            col1, col2, col3 = st.columns(3)
            with col1:
                score_threshold = st.slider("Set score threshold for prediction", min_value=0.0, max_value=1.0, value=0.5, step=0.05)
            with col2:
                st.session_state.scale = st.slider("Set scale for XML file", min_value=0.1, max_value=2.0, value=1.0, step=0.1)

            # Launch the prediction when the user clicks the button
            if st.button("Launch Prediction"):
                st.session_state.crop_image = cropped_image
                with st.spinner('Processing...'):
                    perform_inference(model_object, model_arrow, st.session_state.crop_image, score_threshold)
                st.balloons()
    # Once a prediction exists and an image is loaded, show the annotation and the BPMN result
    if 'prediction' in st.session_state and uploaded_file is not None:
        with st.spinner('Waiting for result display...'):
            display_options(st.session_state.crop_image, score_threshold)

        with st.spinner('Waiting for BPMN modeler...'):
            st.session_state.bpmn_xml = create_XML(st.session_state.prediction.copy(), st.session_state.text_mapping, st.session_state.scale)
            display_bpmn_xml(st.session_state.bpmn_xml)

    # Force garbage collection after display
    gc.collect()
if __name__ == "__main__":
    print('Starting the app...')
    main()