xcurv commited on
Commit
37bed94
·
verified ·
1 Parent(s): bc94c15

Upload appv2.py

Browse files
Files changed (1) hide show
  1. appv2.py +121 -0
appv2.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ from PIL import Image
5
+ from ultralytics import YOLO
6
+ import cv2
7
+
8
+ # Load models
9
+ st.sidebar.title("Settings")
10
+ classification_model = tf.keras.models.load_model('./models.h5')
11
+ detection_model = YOLO('./best.pt')
12
+
13
+ # Load labels
14
+ labels = []
15
+ with open("labels.txt") as f:
16
+ labels = [line.strip() for line in f]
17
+
18
+ # Function to classify image
19
+ def classify_image(img):
20
+ img = img.resize((224, 224)) # Resize image
21
+ img_array = np.array(img)
22
+ img_array = img_array.reshape((-1, 224, 224, 3))
23
+ img_array = tf.keras.applications.efficientnet.preprocess_input(img_array)
24
+
25
+ prediction = classification_model.predict(img_array).flatten()
26
+ confidences = {labels[i]: float(prediction[i]) for i in range(90)}
27
+
28
+ return confidences
29
+
30
+ # Function to detect animals and classify them
31
+ def animal_detect_and_classify(img, detect_results):
32
+ img = np.array(img)
33
+ combined_results = []
34
+
35
+ for result in detect_results:
36
+ for box in result.boxes:
37
+ x1, y1, x2, y2 = map(int, box.xyxy[0])
38
+ detect_img = img[y1:y2, x1:x2]
39
+ detect_img = cv2.resize(detect_img, (224, 224))
40
+ inp_array = np.array(detect_img).reshape((-1, 224, 224, 3))
41
+ inp_array = tf.keras.applications.efficientnet.preprocess_input(inp_array)
42
+
43
+ prediction = classification_model.predict(inp_array)
44
+ confidences_classification = {labels[i]: float(prediction[0][i]) for i in range(90)}
45
+ predicted_labels = [labels[np.argmax(pred)] if np.max(pred) >= 0.66 else "animal" for pred in prediction]
46
+
47
+ combined_results.append(((x1, y1, x2, y2), predicted_labels))
48
+
49
+ return combined_results
50
+
51
+ # Function to generate color for bounding boxes
52
+ def generate_color(class_name):
53
+ color_hash = abs(hash(class_name)) % 16777216
54
+ R = color_hash // (256 * 256)
55
+ G = (color_hash // 256) % 256
56
+ B = color_hash % 256
57
+ return (R, G, B)
58
+
59
+ # Function to draw bounding boxes
60
+ def plot_detected_rectangles(image, detections):
61
+ img_with_rectangles = np.array(image).copy()
62
+
63
+ for rectangle, class_names in detections:
64
+ if class_names[0] == "unknown":
65
+ continue
66
+
67
+ x1, y1, x2, y2 = rectangle
68
+ color = generate_color(class_names[0])
69
+ cv2.rectangle(img_with_rectangles, (x1, y1), (x2, y2), color, 2)
70
+
71
+ for i, class_name in enumerate(class_names):
72
+ cv2.putText(img_with_rectangles, class_name, (x1, y1 - 10 - i*20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
73
+
74
+ return Image.fromarray(img_with_rectangles)
75
+
76
+ # Function to run object detection
77
+ def detection_image(img, conf_threshold, iou_threshold):
78
+ results = detection_model.predict(
79
+ source=img,
80
+ conf=conf_threshold,
81
+ iou=iou_threshold,
82
+ imgsz=640,
83
+ )
84
+
85
+ combined_results = animal_detect_and_classify(img, results)
86
+ plotted_image = plot_detected_rectangles(img, combined_results)
87
+ return plotted_image
88
+
89
+ # Streamlit UI
90
+ st.title("Animal Image Processing")
91
+
92
+ tab1, tab2 = st.tabs(["Image Classification", "Object Detection"])
93
+
94
+ with tab1:
95
+ st.header("Image Classification")
96
+ uploaded_file = st.file_uploader("Upload an image...", type=["jpg", "jpeg", "png"])
97
+
98
+ if uploaded_file is not None:
99
+ image = Image.open(uploaded_file)
100
+ st.image(image, caption="Uploaded Image", use_container_width=True)
101
+
102
+ predictions = classify_image(image)
103
+ sorted_preds = sorted(predictions.items(), key=lambda x: x[1], reverse=True)[:3]
104
+
105
+ st.subheader("Top Predictions:")
106
+ for label, confidence in sorted_preds:
107
+ st.write(f"**{label}**: {confidence*100:.2f}%")
108
+
109
+ with tab2:
110
+ st.header("Object Detection")
111
+ uploaded_file_detect = st.file_uploader("Upload an image for object detection...", type=["jpg", "jpeg", "png"])
112
+
113
+ conf_threshold = st.slider("Confidence Threshold", 0.0, 1.0, 0.25)
114
+ iou_threshold = st.slider("IoU Threshold", 0.0, 1.0, 0.45)
115
+
116
+ if uploaded_file_detect is not None:
117
+ image = Image.open(uploaded_file_detect)
118
+ st.image(image, caption="Uploaded Image", use_container_width=True)
119
+
120
+ detected_image = detection_image(image, conf_threshold, iou_threshold)
121
+ st.image(detected_image, caption="Detected Objects", use_container_width=True)