File size: 4,378 Bytes
293ead4
 
 
 
 
 
 
 
 
 
 
2a2bae3
 
 
 
 
 
 
 
 
 
 
 
 
 
293ead4
 
2a2bae3
293ead4
 
 
 
 
 
 
 
 
 
 
 
 
 
2a2bae3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
293ead4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# cal.py

import torch
from ultralytics import YOLO
import cv2
import numpy as np
import matplotlib.pyplot as plt
import streamlit as st
class Config:
    """Static lookup tables shared by the detection and calorie code.

    ``CLASSES`` is indexed by the integer class id that the model emits
    (``predict_image`` does ``Config.CLASSES[int(class_id)]``), so the ORDER
    of its entries is significant and must not be changed.
    ``CALORIES_DICT`` maps every label in ``CLASSES`` to a calorie figure.
    """

    # Class id -> label, grouped by colour.
    CLASSES = [
        # green (ids 0-14)
        'asparagus', 'avocados', 'broccoli', 'cabbage', 'celery',
        'cucumber', 'green_apples', 'green_beans', 'green_capsicum',
        'green_grapes', 'kiwifruit', 'lettuce', 'limes', 'peas', 'spinach',
        # white / beige (ids 15-26)
        'Banana', 'Cauliflower', 'Date', 'Garlic', 'Ginger', 'Mushroom',
        'Onion', 'Parsnip', 'Peach', 'Pear', 'Potato', 'Turnip',
        # purple / red (ids 27-43)
        'Beetroot', 'Blackberry', 'Blueberry', 'Cherry', 'Eggplant', 'Plum',
        'Purple asparagus', 'Purple grapes', 'Radish', 'Raspberry',
        'Red Apple', 'Red Grape', 'Red cabbage', 'Red capsicum',
        'Strawberry', 'Tomato', 'Watermelon',
        # orange / yellow (ids 44-54)
        'apricot', 'carrot', 'corn', 'grapefruit', 'lemon', 'mango',
        'nectarine', 'orange', 'pineapple', 'pumpkin', 'sweet_potato',
    ]

    # Label -> calories, in the same order as CLASSES.
    # NOTE(review): the figures look like kcal per 100 g, but no unit is
    # stated anywhere in this file — confirm before displaying to users.
    CALORIES_DICT = {
        # green
        'asparagus': 20, 'avocados': 160, 'broccoli': 55, 'cabbage': 25,
        'celery': 16, 'cucumber': 16, 'green_apples': 52, 'green_beans': 31,
        'green_capsicum': 20, 'green_grapes': 69, 'kiwifruit': 61,
        'lettuce': 15, 'limes': 30, 'peas': 81, 'spinach': 23,
        # white / beige
        'Banana': 89, 'Cauliflower': 25, 'Date': 282, 'Garlic': 149,
        'Ginger': 80, 'Mushroom': 22, 'Onion': 40, 'Parsnip': 75,
        'Peach': 39, 'Pear': 57, 'Potato': 77, 'Turnip': 28,
        # purple / red
        'Beetroot': 43, 'Blackberry': 43, 'Blueberry': 57, 'Cherry': 50,
        'Eggplant': 25, 'Plum': 46, 'Purple asparagus': 20,
        'Purple grapes': 69, 'Radish': 16, 'Raspberry': 52,
        'Red Apple': 52, 'Red Grape': 69, 'Red cabbage': 31,
        'Red capsicum': 31, 'Strawberry': 32, 'Tomato': 18,
        'Watermelon': 30,
        # orange / yellow
        'apricot': 48, 'carrot': 41, 'corn': 86, 'grapefruit': 42,
        'lemon': 29, 'mango': 60, 'nectarine': 44, 'orange': 47,
        'pineapple': 50, 'pumpkin': 26, 'sweet_potato': 86,
    }

# Load the model
@st.cache_resource
def load_model():
    model = YOLO('./best.pt')
    return model

def predict_image(image_path, model, conf_threshold=0.03):
    """Run ``model`` on one image file and draw its detections on a copy.

    Args:
        image_path: Path of the image on disk. It is passed to
            ``model.predict`` and re-read with OpenCV for drawing.
        model: A loaded YOLO model (see ``load_model``).
        conf_threshold: Minimum confidence forwarded to ``model.predict``.

    Returns:
        ``(image, detection_details)``:
        ``image`` is the RGB image with a green box and a
        ``"<class>: <confidence>"`` label drawn for every detection.
        ``detection_details`` is a list of dicts, one per detection, with keys
        ``"class"`` (label from ``Config.CLASSES``), ``"top_confidence"`` and
        ``"bbox"`` (``(x1, y1, x2, y2)``).

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_path`` (missing,
            unreadable, or unsupported format).
    """
    results = model.predict(
        source=image_path,
        imgsz=640,
        conf=conf_threshold
    )

    # cv2.imread reports failure by returning None rather than raising;
    # without this check cvtColor fails later with an opaque cv2.error.
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"Could not read image file: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    detection_details = []

    # Each row of boxes.data is unpacked as (x1, y1, x2, y2, confidence,
    # class_id). Move the whole tensor to host memory once, not per row.
    rows = results[0].boxes.data.cpu().numpy()
    for x1, y1, x2, y2, confidence, class_id in rows:
        class_name = Config.CLASSES[int(class_id)]

        cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)),
                      color=(0, 255, 0), thickness=2)

        # putText's origin is the text baseline; keep it at least 15 px from
        # the top so labels of boxes touching the upper edge stay visible.
        label_origin = (int(x1), max(int(y1) - 10, 15))
        label = f"{class_name}: {confidence:.2f}"
        cv2.putText(image, label, label_origin, cv2.FONT_HERSHEY_SIMPLEX,
                    0.5, (0, 255, 0), thickness=1)

        detection_details.append({
            "class": class_name,
            "top_confidence": confidence,
            "bbox": (x1, y1, x2, y2),
        })

    return image, detection_details

def calculate_calories(detection_details):
    """Pair every detection with the calorie figure for its class.

    Args:
        detection_details: Iterable of dicts such as those returned by
            ``predict_image``; each must have ``"class"`` and
            ``"top_confidence"`` keys.

    Returns:
        A list of ``(class_name, calories, confidence)`` tuples, one per
        detection, in input order. A class name missing from
        ``Config.CALORIES_DICT`` raises ``KeyError``.
    """
    return [
        (
            det["class"],
            Config.CALORIES_DICT[det["class"]],
            det["top_confidence"],
        )
        for det in detection_details
    ]