# Spaces: Runtime error  (status banner captured from the hosting page; not part of the program)
import zlib
from datetime import datetime
from io import BytesIO
from pathlib import Path

import cv2
import numpy as np
import tensorflow as tf
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from PIL import Image
from ultralytics import YOLO
# The FastAPI application object for this module.
app = FastAPI()

# CORS is left fully open: any origin, method and header is accepted.
# NOTE(review): wildcard origins combined with allow_credentials=True is a
# combination the FastAPI CORS documentation advises against -- confirm
# whether credentials are needed, or list the allowed origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Both models are loaded once, when the module is imported, and shared by
# every request handler below.
classification_model = tf.keras.models.load_model('./models.h5')
detection_model = YOLO('./best.pt')

# One class name per line; a name's position in this list is the index that
# is looked up from the classifier's output.  The encoding is pinned so the
# result does not depend on the platform's default locale.
with open("labels.txt", encoding="utf-8") as f:
    labels = [line.rstrip("\n") for line in f]
def classify_image(img):
    """Run the classification model on a whole image.

    Args:
        img: A PIL image in any mode; it is converted to RGB and resized to
            224x224 before being passed to the model.

    Returns:
        A dict mapping each label name to the model's output value for that
        label, as a Python float.  Labels and outputs are paired by position;
        if the two sequences differ in length the shorter one wins.
    """
    # Force exactly three channels.  Without this, grayscale/palette images
    # yield a 2-D array and the reshape below fails.
    rgb = img.convert("RGB").resize((224, 224))
    batch = np.asarray(rgb).reshape((1, 224, 224, 3))  # add batch dimension
    batch = tf.keras.applications.efficientnet.preprocess_input(batch)
    scores = classification_model.predict(batch).flatten()
    # zip() instead of a hard-coded class count, so a label file that is
    # shorter or longer than the model output cannot raise IndexError.
    return {label: float(score) for label, score in zip(labels, scores)}
def _predict_labels(region, fallback, threshold=0.66):
    """Classify one RGB PIL image and return one label per prediction row.

    A row whose highest value is below ``threshold`` is reported as
    ``fallback`` instead of a name from ``labels``.
    """
    batch = np.array(region.resize((224, 224))).reshape((-1, 224, 224, 3))
    batch = tf.keras.applications.efficientnet.preprocess_input(batch)
    predictions = classification_model.predict(batch)
    return [
        labels[np.argmax(pred)] if np.max(pred) >= threshold else fallback
        for pred in predictions
    ]


def animal_detect_and_classify(img_path):
    """Detect objects in an image file and classify each detected region.

    Args:
        img_path: Path of the image file to analyse.

    Returns:
        A list of ``((x1, y1, x2, y2), predicted_labels)`` tuples.  For each
        detected box the coordinates come from the detector and low-confidence
        predictions are labelled ``"animal"``.  For a detection result with no
        boxes, the whole image is classified instead, the rectangle is
        ``(0, 0, width, height)`` and low-confidence predictions are labelled
        ``"unknown"``.
    """
    # Convert to RGB up front so RGBA / grayscale / palette files all produce
    # 3-channel crops, and close the file handle as soon as pixels are loaded.
    with Image.open(img_path) as source:
        img = source.convert("RGB")

    # Pass the PIL image itself: Ultralytics treats raw numpy arrays as BGR,
    # so handing it an RGB array would silently swap the colour channels.
    detect_results = detection_model(img)

    combined_results = []
    for result in detect_results:
        boxes_found = False
        for box in result.boxes:
            boxes_found = True
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            predicted_labels = _predict_labels(img.crop((x1, y1, x2, y2)), "animal")
            print(predicted_labels)
            combined_results.append(((x1, y1, x2, y2), predicted_labels))
        if boxes_found:
            continue
        # No detections in this result: fall back to the whole image.
        width, height = img.size
        combined_results.append(((0, 0, width, height), _predict_labels(img, "unknown")))
    return combined_results
def generate_color(class_name):
    """Map a class name to a repeatable 24-bit colour.

    The built-in ``hash()`` is salted per interpreter process for strings, so
    it would give a different colour for the same class after every restart.
    A CRC-32 of the UTF-8 encoded name is stable across runs and machines.

    Args:
        class_name: The label text to derive a colour from.

    Returns:
        A 3-tuple of ints, each in the range 0-255, taken from the low
        24 bits of the checksum (high byte first).
    """
    value = zlib.crc32(class_name.encode("utf-8")) % 16777216
    return (value // (256 * 256), (value // 256) % 256, value % 256)
def plot_detected_rectangles(image, detections, output_path):
    """Draw labelled boxes on a copy of ``image`` and write it to disk.

    Args:
        image: Image array to annotate (the caller passes the result of
            ``cv2.imread``); it is copied, not modified.
        detections: Iterable of ``((x1, y1, x2, y2), class_names)`` pairs.
            Entries whose first class name is ``"unknown"``, or whose name
            list is empty, are skipped.
        output_path: Destination file path for the annotated image.

    Raises:
        ValueError: If ``image`` is ``None`` (e.g. the input could not be read).
        OSError: If the annotated image could not be written.
    """
    if image is None:
        raise ValueError("image is None; the input image could not be read")

    canvas = image.copy()
    for rectangle, class_names in detections:
        if not class_names or class_names[0] == "unknown":
            continue
        x1, y1, x2, y2 = rectangle
        # One colour per class, derived from the first class name.
        color = generate_color(class_names[0])
        cv2.rectangle(canvas, (x1, y1), (x2, y2), color, 2)
        # Stack the class names upward from just above the box.
        for i, class_name in enumerate(class_names):
            text_y = y1 - 10 - i * 20
            if text_y < 15:
                # Not enough room above the box: draw inside its top edge
                # instead of off the canvas.
                text_y = y1 + 20 + i * 20
            cv2.putText(canvas, class_name, (x1, text_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

    # imwrite reports failure through its return value; do not ignore it.
    if not cv2.imwrite(output_path, canvas):
        raise OSError(f"Could not write annotated image to {output_path}")
async def predict_v2(file: UploadFile = File(...)):
    """Store an uploaded image, run detection + classification, save an annotated copy.

    Args:
        file: The uploaded image.

    Returns:
        A dict with a status ``message``, the stored file name under ``out``
        (the same name is used in both ``input/`` and ``output/``), and
        ``class_names``: the first predicted label of every detection.

    Raises:
        HTTPException: 400 if the upload cannot be decoded or stored as an image.
    """
    # NOTE(review): no route decorator is visible for this handler in this
    # view of the file -- confirm it is registered on `app`.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_")
    # The file name is client-controlled: keep only its last path component
    # so it cannot point outside input/ or output/.
    client_name = Path((file.filename or "").replace("\\", "/")).name or "upload"
    filename = timestamp + client_name
    input_path = "input/" + filename
    output_path = "output/" + filename
    Path("input").mkdir(parents=True, exist_ok=True)
    Path("output").mkdir(parents=True, exist_ok=True)

    contents = await file.read()
    try:
        image = Image.open(BytesIO(contents))
        image.save(input_path)
    except (OSError, ValueError) as exc:
        # Undecodable data, or a name/extension PIL cannot save under.
        raise HTTPException(status_code=400, detail="Invalid or unsupported image file") from exc

    detections = animal_detect_and_classify(input_path)
    class_names = [names[0] for _, names in detections]

    frame = cv2.imread(input_path)
    if frame is None:
        raise HTTPException(status_code=400, detail="Invalid or unsupported image file")
    plot_detected_rectangles(frame, detections, output_path)
    return {
        "message": "Detection and classification completed successfully",
        "out": filename,
        "class_names": class_names
    }
# Directory from which annotated images are served.
IMAGE_DIR = Path("output")


async def get_image(image_name: str):
    """Return a file from ``IMAGE_DIR`` by name.

    Args:
        image_name: Name of the file, relative to ``IMAGE_DIR``.

    Returns:
        A ``FileResponse`` for the requested file.

    Raises:
        HTTPException: 400 if the name contains ``"../"`` or resolves to a
            location outside ``IMAGE_DIR``; 404 if no such file exists.
    """
    # NOTE(review): no route decorator is visible for this handler in this
    # view of the file -- confirm it is registered on `app`.
    if "../" in image_name:
        raise HTTPException(status_code=400, detail="Invalid image name")
    # A substring check alone is not enough: an absolute name such as
    # "/etc/passwd" replaces the base directory when joined, and a trailing
    # ".." escapes it.  Resolve the path and verify it is still inside.
    try:
        base_dir = IMAGE_DIR.resolve()
        image_path = (base_dir / image_name).resolve()
        image_path.relative_to(base_dir)
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid image name") from None
    if not image_path.is_file():
        raise HTTPException(status_code=404, detail="Image not found")
    return FileResponse(image_path)
async def predict(file: bytes = File(...)):
    """Classify an uploaded image and return the value for every label.

    Args:
        file: Raw bytes of the uploaded image.

    Returns:
        The dict produced by ``classify_image``: label name -> float.

    Raises:
        HTTPException: 400 if the bytes cannot be decoded as an image.
    """
    # NOTE(review): no route decorator is visible for this handler in this
    # view of the file -- confirm it is registered on `app`.
    try:
        img = Image.open(BytesIO(file))
        # Image.open is lazy; force decoding here so corrupt data is
        # reported as a client error rather than failing later.
        img.load()
    except OSError as exc:
        raise HTTPException(status_code=400, detail="Invalid image file") from exc
    return classify_image(img)