Spaces:
Sleeping
Sleeping
import os | |
import base64 | |
import json | |
import streamlit as st | |
from streamlit_option_menu import option_menu | |
from streamlit_authenticator import Authenticate | |
import yaml | |
from yaml.loader import SafeLoader | |
import pandas as pd | |
from PIL import Image | |
import numpy as np | |
import torch | |
import cv2 | |
from transformers import AutoImageProcessor, AutoModelForObjectDetection | |
import bbox_visualizer as bbv | |
from st_clickable_images import clickable_images | |
from glob import glob | |
MODEL_PATH = "pankaj-munde/FScout_v0.2"

# Box colours (RGB), indexed by the detector's class id.
colors = [[236, 112, 99], [165, 105, 189], [225, 9, 232], [255, 38, 8],
          [247, 249, 249], [170, 183, 184], [247, 249, 249], [247, 249, 249]]


@st.cache_resource
def _load_detr(model_path):
    """Load the DETR image processor and object-detection model.

    Streamlit re-executes the whole script on every widget interaction, so
    the loaded objects are cached to avoid re-initialising the model weights
    on each rerun. The Hugging Face token is read from
    ``st.secrets["HF_TOKEN"]``.

    Returns
    -------
    tuple
        ``(processor, model)`` for ``model_path``.
    """
    hf_token = st.secrets["HF_TOKEN"]
    processor = AutoImageProcessor.from_pretrained(model_path, token=hf_token)
    model = AutoModelForObjectDetection.from_pretrained(model_path, token=hf_token)
    return processor, model


detr_preprocessor, detr_model = _load_detr(MODEL_PATH)

# Authentication settings are stored as a JSON string in Streamlit secrets;
# the keys used are "credentials", "cookie" and "preauthorized".
config = json.loads(st.secrets["CONFIG"])
authenticator = Authenticate(
    config['credentials'],
    config['cookie']['name'],
    config['cookie']['key'],
    config['cookie']['expiry_days'],
    config['preauthorized'],
)
# Renders the login form; the app body below gates on
# st.session_state["authentication_status"].
name, authentication_status, username = authenticator.login('Login', 'main')
def get_detr_predictions(image, thresh):
    """Run the module-level DETR detector on a single image.

    Parameters
    ----------
    image : PIL.Image.Image
        Image to analyse. Its ``size`` (width, height) is flipped to
        (height, width) so predicted boxes are rescaled to pixel coordinates.
    thresh : float or str
        Minimum detection score; passed through ``float()`` before use.

    Returns
    -------
    dict
        The single per-image result of ``post_process_object_detection``;
        callers read its ``"scores"``, ``"labels"`` and ``"boxes"`` entries.
    """
    with torch.no_grad():
        model_inputs = detr_preprocessor(images=image, return_tensors="pt")
        raw_outputs = detr_model(**model_inputs)
        width, height = image.size
        size_tensor = torch.tensor([(height, width)])
        per_image = detr_preprocessor.post_process_object_detection(
            raw_outputs,
            threshold=float(thresh),
            target_sizes=size_tensor,
        )
    return per_image[0]
def add_label(img,
              label,
              bbox,
              draw_bg=True,
              text_bg_color=(255, 255, 255),
              text_color=(0, 0, 0),
              top=True):
    """Write a text label at the top-left corner of a bounding box.

    Parameters
    ----------
    img : ndarray
        Image to draw on (modified in place), preferably one that already
        has the rectangular bounding box drawn.
    label : str
        Text to write.
    bbox : list
        ``[x_min, y_min, x_max, y_max]`` of the rectangle; only ``x_min`` and
        ``y_min`` are used.
    draw_bg : bool, optional
        If True, fill a 30-px-high strip behind the text with
        ``text_bg_color``; otherwise only the text is drawn. Default True.
    text_bg_color : tuple, optional
        Fill colour of the label strip. Default (255, 255, 255).
    text_color : tuple, optional
        Colour of the label text. Default (0, 0, 0).
    top : bool, optional
        If True, place the label just above the box's top edge; otherwise
        place it just inside the box, below the top edge. Default True.

    Returns
    -------
    ndarray
        ``img`` with the label drawn on it.
    """
    text_width = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 1, 2)[0][0]
    x_min, y_min = bbox[0], bbox[1]

    if top:
        # Strip occupies the 30 px directly above the top edge, matching the
        # text baseline at y_min - 5. (Previously the strip was drawn below
        # the edge, so it did not sit behind the text.)
        strip_top, strip_bottom = y_min - 30, y_min
        text_y = y_min - 5
    else:
        # Strip occupies the 30 px directly below the top edge (inside box).
        strip_top, strip_bottom = y_min, y_min + 30
        text_y = y_min - 5 + 30

    if draw_bg:
        cv2.rectangle(img, (x_min, strip_top),
                      (x_min + text_width + 5, strip_bottom), text_bg_color, -1)
    cv2.putText(img, label, (x_min + 5, text_y),
                cv2.FONT_HERSHEY_SIMPLEX, 1, text_color, 2)
    return img
def _color_for(label_id):
    """Return the palette colour for a class id, cycling when ids exceed the palette."""
    return colors[int(label_id) % len(colors)]


def image_checkup_v4(ipath, thresh, show_count):
    """Detect insects in an image and draw their bounding boxes.

    Parameters
    ----------
    ipath : PIL.Image.Image
        Input image (despite the name, an image object, not a path). It is
        converted to RGB before inference; the caller's object is not drawn on.
    thresh : float
        Minimum detection score, forwarded to ``get_detr_predictions``.
    show_count : bool
        If True, each box is labelled with its 1-based detection index.

    Returns
    -------
    tuple[list[dict], numpy.ndarray]
        One dict per detected class with keys ``"Insect"`` (class name),
        ``"Count"`` (number of boxes) and ``"Color"`` (RGB list used for its
        boxes), in ``pandas.Series.value_counts`` order (most frequent
        first). The list is empty when nothing is detected. The second item
        is the RGB image array with the boxes drawn.
    """
    image = ipath.convert("RGB")
    detr_results = get_detr_predictions(image, thresh)

    label_ids = detr_results["labels"].numpy()
    boxes = detr_results["boxes"].numpy()

    # np.array copies the pixels, so drawing never touches `image`.
    img_with_box = np.array(image)
    predicted_labels = []
    for idx, (label_id, box) in enumerate(zip(label_ids, boxes)):
        label_id = int(label_id)
        bbox = list(np.array(box, dtype=int))
        img_with_box = bbv.draw_rectangle(img_with_box, bbox, bbox_color=_color_for(label_id))
        if show_count:
            img_with_box = bbv.add_label(
                img_with_box, f"{idx + 1}", bbox, draw_bg=True, top=True)
        predicted_labels.append(detr_model.config.id2label[label_id])

    if not predicted_labels:
        return [], img_with_box

    counts = pd.Series(predicted_labels).value_counts()
    summary = [
        {
            "Insect": insect,
            "Count": count,
            "Color": _color_for(detr_model.config.label2id[insect]),
        }
        for insect, count in counts.items()
    ]
    return summary, img_with_box
def add_logo(logo_path, width, height):
    """Open and return the logo image stored at ``logo_path``.

    The image is returned at its native size: ``width`` and ``height`` are
    accepted so existing callers keep working, but no resizing is performed.
    """
    return Image.open(logo_path)
def _render_predictions(predicted_data, result_image):
    """Show the per-insect count legend next to the annotated image.

    ``predicted_data`` is the list of ``{"Insect", "Count", "Color"}`` dicts
    produced by ``image_checkup_v4``; ``result_image`` is the image to show.
    """
    legend_col, image_col = st.columns([2, 4])
    with legend_col:
        st.subheader("🎯 Predicted Labels")
        total_count = sum(d["Count"] for d in predicted_data)
        st.write(f"<h3>Total Count : {total_count}</h3>", unsafe_allow_html=True)
        for i, d in enumerate(predicted_data):
            r, g, b = d["Color"][0], d["Color"][1], d["Color"][2]
            # One legend row: index, colour swatch, "name : count".
            html_string = (
                '<div style="display: flex; align-items: center;">'
                f'<b style="margin-right: 15px">{i + 1}. </b>'
                f'<div style="background-color: rgb({r}, {g}, {b}); width: 20px; '
                'height: 20px; border: 1px solid black; margin-right: 10px;"></div>'
                f'<p style="margin-top: 15px"><b>{d["Insect"]} : {d["Count"]} </b></p>'
                '</div>'
            )
            st.markdown(html_string, unsafe_allow_html=True)
        st.write("<hr/>", unsafe_allow_html=True)
    with image_col:
        st.subheader("🔍 Predicted Image")
        st.write("<br/>", unsafe_allow_html=True)
        st.image(result_image)


if st.session_state["authentication_status"]:
    with st.sidebar:
        my_logo = add_logo(logo_path="./static//FarmGyan logo_1.png", width=50, height=60)
        st.image(my_logo)

    # Greeting on the left, logout button on the right.
    ucol, bcol = st.columns([3, 2])
    ucol.write(f'## Welcome *{st.session_state["name"]}*')
    with bcol:
        authenticator.logout('Logout', 'main')
    st.write("<hr/>", unsafe_allow_html=True)

    st.title(":seedling: FarmGyan | Insects Scouting")
    st.write("<hr/>", unsafe_allow_html=True)
    st.write("## 📁 Upload image for prediction")
    uploaded_file = st.file_uploader("Choose an image file", type=["jpg", "jpeg", "png"])
    st.write("<hr/>", unsafe_allow_html=True)

    # Detection settings in the sidebar.
    st.sidebar.write("## ⚙️ Configurations")
    st.sidebar.write("<hr/>", unsafe_allow_html=True)
    st.sidebar.write("#### Prediction Threshold")
    thresh = st.sidebar.slider("Threshold", 0.0, 1.0, 0.7, 0.1)
    st.sidebar.write("#### Boxes Count")
    show_count = st.sidebar.checkbox("Show Count")

    if uploaded_file is not None:
        try:
            image = Image.open(uploaded_file).convert("RGB")
        except OSError:
            # The extension filter does not guarantee the bytes are a valid
            # image; Pillow signals unreadable/truncated data with OSError
            # (UnidentifiedImageError is a subclass).
            st.error("The uploaded file could not be read as an image.")
        else:
            # Show the spinner while the slow model call runs.
            with st.spinner(text='In progress'):
                predicted_data, result_image = image_checkup_v4(image, thresh, show_count)
            _render_predictions(predicted_data, result_image)
elif st.session_state["authentication_status"] is False:
    st.error("Username/password is incorrect")