Spaces:
Runtime error
Runtime error
import gradio as gr | |
import numpy as np | |
import pandas as pd | |
from app_utils import annotate_planogram_compliance, do_sorting, xml_to_csv | |
from inference import run | |
import json | |
import os | |
from tempfile import NamedTemporaryFile | |
# Product labels handed to annotate_planogram_compliance. Presumably indexed
# by the detector's integer class id — TODO confirm against the trained model.
# NOTE(review): this list is incomplete (see the placeholder below); every
# class id the detector can emit presumably needs a matching entry here.
target_names = list(
    (
        "Bottle,100PLUS ACTIVE 1.5L",
        "Bottle,100PLUS ACTIVE 500ML",
        "Bottle,100PLUS LEMON LIME 1.5L",
        # Add all other target names here
    )
)
# Define the function to run planogram compliance check
def _load_sorted_annotations(annotation_file):
    """Return the reference annotations passed through ``do_sorting``, or None.

    A pandas DataFrame (what a ``gr.Dataframe`` input delivers) is used as-is,
    and an empty one counts as "no annotation supplied". Any other non-None
    value is handed to ``xml_to_csv`` unchanged.
    """
    if annotation_file is None:
        return None
    if isinstance(annotation_file, pd.DataFrame):
        if annotation_file.empty:
            return None
        annotation_df = annotation_file
    else:
        annotation_df = xml_to_csv(annotation_file)
    return do_sorting(annotation_df)


def planogram_compliance_check(planogram_image, master_planogram_image, annotation_file):
    """Gradio handler: detect products in a shelf photo and score compliance.

    Args:
        planogram_image: photo of the shelf to check (image or array-like).
        master_planogram_image: reference planogram image; converted to an
            array and passed through to ``run_compliance_check``.
        annotation_file: reference annotations — a source accepted by
            ``xml_to_csv``, an already-tabular pandas DataFrame, or None.

    Returns:
        ``(compliance_score, annotated_image)`` as produced by
        ``run_compliance_check``.

    Raises:
        ValueError: if no planogram image was supplied.
    """
    if planogram_image is None:
        # Fail with a clear message instead of feeding np.array(None) to the detector.
        raise ValueError("A planogram image is required.")

    planogram_img = np.array(planogram_image)
    master_planogram_img = np.array(master_planogram_image)

    # NOTE(review): these are YOLOv5-style detect() arguments. If inference.run
    # mirrors YOLOv5's detect.run, `source` must be a file path/URL rather than
    # an in-memory array — confirm against inference.run.
    result_list = run(
        weights="base_line_best_model_exp5.pt",
        source=planogram_img,
        imgsz=[640, 640],
        conf_thres=0.6,
        iou_thres=0.6,
    )

    sorted_xml_df = _load_sorted_annotations(annotation_file)

    compliance_score, annotated_image = run_compliance_check(
        planogram_img, master_planogram_img, sorted_xml_df, result_list
    )
    return compliance_score, annotated_image
# Sentinel class id marking a shelf slot that holds no product. Real class ids
# are presumably non-negative indexes into `target_names` — TODO confirm.
_EMPTY_SLOT = -1

# Column layout assumed for each detection row returned by inference.run
# (box corners, confidence, class id) — NOTE(review): confirm against run().
_DETECTION_COLUMNS = ["xmin", "ymin", "xmax", "ymax", "conf", "cls"]


def _detections_to_df(result_list):
    """Wrap raw detections in a DataFrame with ``_DETECTION_COLUMNS``."""
    if isinstance(result_list, pd.DataFrame):
        return result_list
    rows = [] if result_list is None else list(result_list)
    return pd.DataFrame(rows, columns=_DETECTION_COLUMNS)


def _build_product_table(df):
    """Lay out the ``cls`` values of *df* as a (shelf line x slot) int grid.

    One row per distinct ``line_number`` (ascending). Within a line, products
    keep their row order in *df*, assumed to be left-to-right after
    ``do_sorting`` — TODO confirm. Short lines are padded with ``_EMPTY_SLOT``.
    """
    if df is None or len(df) == 0:
        return np.full((0, 0), _EMPTY_SLOT, dtype=int)
    shelves = [
        df.loc[df["line_number"] == line, "cls"].astype(int).to_numpy()
        for line in sorted(df["line_number"].unique())
    ]
    table = np.full((len(shelves), max(map(len, shelves))), _EMPTY_SLOT, dtype=int)
    for row, products in enumerate(shelves):
        table[row, : len(products)] = products
    return table


def _pad_table(table, shape):
    """Return *table* enlarged to *shape*, filling new cells with ``_EMPTY_SLOT``."""
    padded = np.full(shape, _EMPTY_SLOT, dtype=int)
    padded[: table.shape[0], : table.shape[1]] = table
    return padded


def run_compliance_check(planogram_img, master_planogram_img, sorted_xml_df, result_list):
    """Compare detected products with the reference layout and annotate the image.

    Args:
        planogram_img: array of the photographed shelf; copied, never modified.
        master_planogram_img: accepted for interface compatibility; not read here.
        sorted_xml_df: reference annotations already passed through
            ``do_sorting`` (must provide ``xmin``, ``ymin``, ``xmax``, ``ymax``,
            ``line_number`` and ``cls``), or None when no annotation was given.
        result_list: detections from ``inference.run``; each row is assumed to
            follow ``_DETECTION_COLUMNS``.

    Returns:
        ``(compliance_score, annotated_image)``. The score is the fraction of
        reference slots holding a product whose detected class matches, or 0.0
        when there is no reference product to compare against.
    """
    annotated_image = planogram_img.copy()
    if sorted_xml_df is None:
        # No reference layout: nothing to score or mark up.
        return 0.0, annotated_image

    annotate_df = sorted_xml_df[["xmin", "ymin", "xmax", "ymax", "line_number", "cls"]].astype(int)

    master_table = _build_product_table(sorted_xml_df)
    detections = _detections_to_df(result_list)
    # do_sorting presumably assigns `line_number`; skip it when nothing was detected.
    detected_table = _build_product_table(do_sorting(detections) if len(detections) else detections)

    # Bring both grids to one shape so they compare cell by cell even when the
    # detector found more or fewer shelf lines/slots than the reference has.
    shape = (
        max(master_table.shape[0], detected_table.shape[0]),
        max(master_table.shape[1], detected_table.shape[1]),
    )
    master_table = _pad_table(master_table, shape)
    detected_table = _pad_table(detected_table, shape)

    matches = master_table == detected_table
    expected = master_table != _EMPTY_SLOT
    correct_indexes = np.where(matches)
    wrong_indexes = np.where(~matches)

    annotated_image = annotate_planogram_compliance(
        annotated_image, annotate_df, correct_indexes, wrong_indexes, target_names
    )

    # Only slots where the reference expects a product count towards the score.
    total_products = int(np.count_nonzero(expected))
    correct_matches = int(np.count_nonzero(matches & expected))
    compliance_score = correct_matches / total_products if total_products else 0.0
    return compliance_score, annotated_image
# Gradio interface.
# Top-level component classes (gr.Image, gr.File, gr.Textbox) replace the
# gr.inputs / gr.outputs namespaces, which were deprecated in gradio 3 and
# removed in gradio 4.
planogram_check_interface = gr.Interface(
    fn=planogram_compliance_check,
    inputs=[
        gr.Image(label="Planogram Image"),
        gr.Image(label="Master Planogram Image"),
        # The annotation is an XML file, so accept an upload rather than a
        # hand-edited table; the handler forwards it to xml_to_csv.
        # NOTE(review): confirm xml_to_csv accepts what gr.File yields
        # (a file path in gradio 4).
        gr.File(label="Annotation File (XML)"),
    ],
    outputs=[
        gr.Textbox(label="Compliance Score"),
        gr.Image(label="Annotated Planogram Image"),
    ],
    title="Planogram Compliance Checker",
    description="Upload planogram image, master planogram image, and annotation file (if available) to check compliance.",
)

# Launch the interface only when executed as a script, not on import.
if __name__ == "__main__":
    planogram_check_interface.launch()