FP-KCV / appv2.py
xcurv's picture
Upload appv2.py
37bed94 verified
raw
history blame
4.32 kB
import zlib

import streamlit as st
import numpy as np
import tensorflow as tf
from PIL import Image
from ultralytics import YOLO
import cv2
# Load models.
# Streamlit re-executes this whole script on every widget interaction
# (slider move, file upload), so uncached loads would re-read the Keras and
# YOLO weights from disk on every rerun. st.cache_resource keeps one shared
# instance per server process instead.
# NOTE(review): cached models are not reloaded if the weight files change
# while the server is running — restart the app after replacing them.
st.sidebar.title("Settings")


@st.cache_resource
def _load_classification_model():
    """Load the Keras classifier from ./models.h5 once per server process."""
    return tf.keras.models.load_model('./models.h5')


@st.cache_resource
def _load_detection_model():
    """Load the YOLO detector from ./best.pt once per server process."""
    return YOLO('./best.pt')


classification_model = _load_classification_model()
detection_model = _load_detection_model()

# Load labels: one class name per line; line i names model output i.
# Explicit UTF-8 so non-ASCII labels decode the same on every platform
# (the default encoding is locale-dependent, e.g. cp1252 on Windows).
with open("labels.txt", encoding="utf-8") as f:
    labels = [line.strip() for line in f]
# Function to classify image
def classify_image(img):
    """Score a whole image against every class of the Keras classifier.

    Args:
        img: PIL image of any mode; it is converted to RGB before resizing.

    Returns:
        dict mapping each label (from the module-level ``labels`` list) to the
        classifier's float score for that class.
    """
    # The model is fed a (batch, 224, 224, 3) array. Without convert("RGB"),
    # an RGBA or palette PNG (both accepted by the uploader) or a grayscale
    # image has the wrong channel count and the reshape below raises.
    img = img.convert("RGB").resize((224, 224))
    img_array = np.array(img).reshape((-1, 224, 224, 3))
    img_array = tf.keras.applications.efficientnet.preprocess_input(img_array)
    prediction = classification_model.predict(img_array).flatten()
    # Pair labels with scores instead of assuming exactly 90 classes. zip stops
    # at the shorter sequence, so a labels.txt / model-output size mismatch no
    # longer raises IndexError (extra entries on either side are ignored).
    return {label: float(score) for label, score in zip(labels, prediction)}
# Function to detect animals and classify them
def animal_detect_and_classify(img, detect_results, min_confidence=0.66, fallback_label="animal"):
    """Classify the image crop under every detected box.

    Args:
        img: The image the detector ran on (PIL image or array). PIL images
            are converted to RGB so every crop has three channels.
        detect_results: Iterable of detector results; each exposes ``.boxes``
            whose items carry corner coordinates in ``.xyxy[0]``.
        min_confidence: Minimum top classifier score needed to use the
            predicted label; below it ``fallback_label`` is used.
        fallback_label: Label given to crops the classifier is unsure about.

    Returns:
        List of ``((x1, y1, x2, y2), predicted_labels)`` tuples, one per box
        with a non-empty crop. ``predicted_labels`` has one entry, because
        each crop is classified as a batch of one.
    """
    # RGBA/"L"/"P" PIL images would give crops with the wrong channel count
    # for the (batch, 224, 224, 3) reshape below.
    if isinstance(img, Image.Image):
        img = img.convert("RGB")
    img = np.array(img)
    height, width = img.shape[:2]
    combined_results = []
    for result in detect_results:
        for box in result.boxes:
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            # Clamp to the frame: a negative index would wrap around in numpy
            # slicing, and a zero-area crop makes cv2.resize raise.
            crop = img[max(y1, 0):min(y2, height), max(x1, 0):min(x2, width)]
            if crop.size == 0:
                continue
            crop = cv2.resize(crop, (224, 224))
            inp_array = tf.keras.applications.efficientnet.preprocess_input(
                crop.reshape((-1, 224, 224, 3))
            )
            prediction = classification_model.predict(inp_array)
            predicted_labels = [
                labels[np.argmax(pred)] if np.max(pred) >= min_confidence else fallback_label
                for pred in prediction
            ]
            combined_results.append(((x1, y1, x2, y2), predicted_labels))
    return combined_results
# Function to generate color for bounding boxes
def generate_color(class_name):
    """Return a deterministic (R, G, B) color for a class name.

    Uses CRC-32 instead of the built-in ``hash()``: str hashes are salted
    per process (PYTHONHASHSEED), so hash()-based colors changed every time
    the app restarted. The low 24 bits of the checksum are split into three
    8-bit channels.

    Args:
        class_name: Label text.

    Returns:
        Tuple of three ints, each in [0, 255].
    """
    color_hash = zlib.crc32(class_name.encode("utf-8")) & 0xFFFFFF
    return (color_hash >> 16, (color_hash >> 8) & 0xFF, color_hash & 0xFF)
# Function to draw bounding boxes
def plot_detected_rectangles(image, detections):
    """Draw a box and label(s) for each detection and return a new PIL image.

    Args:
        image: PIL image (any mode; converted to RGB) or array to draw on.
            The input itself is not modified.
        detections: Iterable of ``((x1, y1, x2, y2), class_names)`` pairs, as
            returned by ``animal_detect_and_classify``.

    Returns:
        PIL image with the annotations drawn on it.
    """
    # Draw on an RGB copy. On an RGBA array, cv2 fills the 4th channel of a
    # 3-tuple color with 0, so boxes would be fully transparent; on a
    # single-channel ("L"/"P") array only the first color component is used.
    if isinstance(image, Image.Image):
        image = image.convert("RGB")
    canvas = np.array(image)  # np.array copies, so the caller's image is untouched
    for (x1, y1, x2, y2), class_names in detections:
        # NOTE(review): animal_detect_and_classify labels low-confidence crops
        # "animal", not "unknown", so this skip only fires for a classifier
        # label literally named "unknown" — confirm which is intended.
        if not class_names or class_names[0] == "unknown":
            continue
        color = generate_color(class_names[0])
        cv2.rectangle(canvas, (x1, y1), (x2, y2), color, 2)
        # Labels normally stack upward above the box. When the topmost one
        # would be clipped by the image's top edge (~12 px is roughly the
        # glyph height at scale 0.5), stack them downward inside the box.
        draw_above = y1 - 10 - (len(class_names) - 1) * 20 >= 12
        for i, class_name in enumerate(class_names):
            text_y = y1 - 10 - i * 20 if draw_above else y1 + 20 + i * 20
            cv2.putText(canvas, class_name, (x1, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
    return Image.fromarray(canvas)
# Function to run object detection
def detection_image(img, conf_threshold, iou_threshold):
    """Detect objects in ``img``, classify each box, and return the annotated image.

    Args:
        img: Image handed to the YOLO detector as ``source`` and reused for
            cropping and drawing.
        conf_threshold: Forwarded to the detector as ``conf``.
        iou_threshold: Forwarded to the detector as ``iou``.

    Returns:
        The image built by ``plot_detected_rectangles`` (a PIL image).
    """
    yolo_kwargs = dict(source=img, conf=conf_threshold, iou=iou_threshold, imgsz=640)
    detections = detection_model.predict(**yolo_kwargs)
    labelled_boxes = animal_detect_and_classify(img, detections)
    return plot_detected_rectangles(img, labelled_boxes)
# Streamlit UI
def _open_uploaded_image(uploaded_file):
    """Decode an uploaded file into a PIL image; report failures and return None.

    Returns None without output when nothing has been uploaded. The uploader's
    ``type`` filter only checks the file extension, so a corrupt or
    mislabelled file can still arrive here. Pillow signals that with an
    ``OSError`` (``UnidentifiedImageError`` is a subclass).
    """
    if uploaded_file is None:
        return None
    try:
        image = Image.open(uploaded_file)
        # Force the full decode now so a truncated file fails here, with a
        # readable message, rather than later inside inference.
        image.load()
    except OSError as err:
        st.error(f"Could not read the uploaded file as an image: {err}")
        return None
    return image


st.title("Animal Image Processing")
tab1, tab2 = st.tabs(["Image Classification", "Object Detection"])

with tab1:
    st.header("Image Classification")
    uploaded_file = st.file_uploader("Upload an image...", type=["jpg", "jpeg", "png"])
    image = _open_uploaded_image(uploaded_file)
    if image is not None:
        st.image(image, caption="Uploaded Image", use_container_width=True)
        predictions = classify_image(image)
        sorted_preds = sorted(predictions.items(), key=lambda x: x[1], reverse=True)[:3]
        st.subheader("Top Predictions:")
        for label, confidence in sorted_preds:
            st.write(f"**{label}**: {confidence*100:.2f}%")

with tab2:
    st.header("Object Detection")
    uploaded_file_detect = st.file_uploader("Upload an image for object detection...", type=["jpg", "jpeg", "png"])
    conf_threshold = st.slider("Confidence Threshold", 0.0, 1.0, 0.25)
    iou_threshold = st.slider("IoU Threshold", 0.0, 1.0, 0.45)
    image = _open_uploaded_image(uploaded_file_detect)
    if image is not None:
        st.image(image, caption="Uploaded Image", use_container_width=True)
        detected_image = detection_image(image, conf_threshold, iou_threshold)
        st.image(detected_image, caption="Detected Objects", use_container_width=True)