Upload 3 files
- app.py +148 -0
- regression_model.pth +3 -0
- yolo.pt +3 -0
app.py
ADDED
@@ -0,0 +1,148 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gradio as gr
from PIL import Image
import torchvision
import torchvision.transforms as transforms
from ultralytics import YOLO
import cv2


# Regression model: a ResNet101 backbone with a small MLP head that
# predicts the rotten-part percentage of a cropped produce image.
class RegressionModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Architecture only; the trained weights are loaded from
        # regression_model.pth below, so ImageNet weights are not needed.
        resnet = torchvision.models.resnet101(weights=None)
        # Keep everything up to and including the global average pool,
        # dropping the final fully connected classification layer.
        self.features = nn.Sequential(*list(resnet.children())[:-1])
        # Regression head: 2048 -> 512 -> 64 -> 1
        self.regressor1 = nn.Linear(2048, 512)
        self.regressor2 = nn.Linear(512, 64)
        self.regressor3 = nn.Linear(64, 1)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 2048)
        x = F.gelu(self.regressor1(x))
        x = F.gelu(self.regressor2(x))
        return self.regressor3(x)


# Load the trained regression weights and switch to evaluation mode.
model_reg = RegressionModel()
model_reg.load_state_dict(torch.load("regression_model.pth", map_location="cpu"))
model_reg.eval()

# Load the YOLO detector once at startup rather than on every request.
model_yolo = YOLO("yolo.pt")

# ImageNet normalization for the regression model's test-time transforms.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
test_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
)


def regression(image):
    """Predict the rotten-part percentage for one cropped detection."""
    # The crops come from an OpenCV (BGR) image, so convert back to RGB
    # before applying the PIL-based transforms.
    pil_img = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    image_tensor = test_transforms(pil_img).unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        output = model_reg(image_tensor)
    return output.item()


def hugg_face(img):
    """Detect produce items, then run rotten-part regression on each crop."""
    labels = [
        "freshpeach",
        "freshlemon",
        "rottenpeach",
        "rotten lemon",
        "freshmandarin",
        "rottenmandarin",
        "freshtomato",
        "rottentomato",
        "freshcucumber",
        "rottencucumber",
    ]
    img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
    results = model_yolo(img)
    img_label_results = []
    img_2 = img.copy()

    # Draw each detection on a copy of the image and keep the crop
    # for the regression pass.
    for result in results:
        for i, cls in enumerate(result.boxes.cls):
            x1, y1, x2, y2 = (int(v) for v in result.boxes.xyxy[i])
            crop_img = img[y1:y2, x1:x2]
            cv2.rectangle(img_2, (x1, y1), (x2, y2), (0, 255, 0), 2)
            cv2.putText(
                img_2,
                labels[int(cls)] + str(i),
                (x1, y1),
                cv2.FONT_HERSHEY_SIMPLEX,
                3,
                (0, 255, 0),
                2,
                cv2.LINE_AA,
            )
            img_label_results.append(
                {"label": labels[int(cls)] + str(i), "crop_img": crop_img}
            )

    img_2_pil = Image.fromarray(cv2.cvtColor(img_2, cv2.COLOR_BGR2RGB))

    # Run the regression model on each cropped detection.
    regression_results = []
    for item in img_label_results:
        regression_output = regression(item["crop_img"])
        regression_results.append(
            {"label": item["label"], "Rotten Part Percentage": round(regression_output, 2)}
        )
    return img_2_pil, regression_results


# Gradio interface: one input image in, an annotated image plus the
# per-detection regression results out.
inputs = gr.Image(type="pil")
outputs = [
    gr.Image(type="pil", label="Detection Result"),
    gr.Textbox(label="Regression Results"),
]

app = gr.Interface(
    fn=hugg_face,
    inputs=inputs,
    outputs=outputs,
    title="Smart Fridge with Regression",
    description="Rotten part regression results",
)

# Launch the app
app.launch(share=True)
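As a quick sanity check of app.py before committing, hugg_face can be called directly on a local image; a minimal sketch, assuming a test image exists in the working directory (sample.jpg is a placeholder name):

from PIL import Image

# Placeholder test image; any RGB photo of produce will do.
annotated, results = hugg_face(Image.open("sample.jpg"))
annotated.save("annotated.jpg")  # image with boxed, labelled detections
print(results)  # per-detection rotten-part percentages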
regression_model.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b4f3cfee4eb69b74a4da298cc2bb504d07acfc278e0cc6b79065f9454e9db839
size 174966352
yolo.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:37a2c1e3ffdfb2b69d2e3c5f127722e5940edf16bc1b8176be4495751436fcfc
size 22516185
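Note that regression_model.pth and yolo.pt are stored as Git LFS pointers, so the entries above contain only the pointer metadata, not the ~175 MB and ~23 MB weight binaries. One way to fetch the actual files programmatically is via huggingface_hub; a sketch, with the repo id as a placeholder for this Space's actual id:

from huggingface_hub import hf_hub_download

# Placeholder repo id; substitute the real <user>/<space> of this Space.
reg_path = hf_hub_download(repo_id="<user>/<space>", filename="regression_model.pth", repo_type="space")
yolo_path = hf_hub_download(repo_id="<user>/<space>", filename="yolo.pt", repo_type="space")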