SamDaLamb committed
Commit 5a89f37 (verified) · Parent: b5b4697

Update app.py

Files changed (1)
  app.py (+143, -59)
app.py CHANGED
@@ -8,73 +8,157 @@ import gradio as gr
Old side (lines 8-80):
  import numpy as np
  import requests
  from PIL import Image
  import torch
  from torchvision import transforms
  from PIL import Image

  # Load the YOLO model
- model_path = "./best-model.torchscript"
  model = torch.jit.load(model_path, map_location=torch.device("cpu"))
  model.eval()

- # Initialize your pose estimation model
- yolo_nas_pose = models.get("best.pt",
-                            num_classes=1,
-                            checkpoint_path="./best.pt")
-
- def process_and_predict(url=None,
-                         image=None,
-                         confidence=0.5,
-                         iou=0.5):
-     # If a URL is provided, use it directly for prediction
-     if url is not None and url.strip() != "":
-         response = requests.get(url)
-         image = Image.open(BytesIO(response.content))
-         image = np.array(image)
-         result = yolo_nas_pose.predict(image, conf=confidence, iou=iou)
-     # If a file is uploaded, read it, convert it to a numpy array and use it for prediction
-     elif image is not None:
-         result = yolo_nas_pose.predict(image, conf=confidence, iou=iou)
-     else:
-         return None  # If no input is provided, return None
-
-     # Extract prediction data
-     image_prediction = result._images_prediction_lst[0]
-
-     pose_data = image_prediction.prediction
-
-     # Visualize the prediction
-     output_image = PoseVisualization.draw_poses(
-         image=image_prediction.image,
-         poses=pose_data.poses,
-         boxes=pose_data.bboxes_xyxy,
-         scores=pose_data.scores,
-         is_crowd=None,
-         edge_links=pose_data.edge_links,
-         edge_colors=pose_data.edge_colors,
-         keypoint_colors=pose_data.keypoint_colors,
-         joint_thickness=2,
-         box_thickness=2,
-         keypoint_radius=5
-     )
-
-     blank_image = np.zeros_like(image_prediction.image)
-
-     skeleton_image = PoseVisualization.draw_poses(
-         image=blank_image,
-         poses=pose_data.poses,
-         boxes=pose_data.bboxes_xyxy,
-         scores=pose_data.scores,
-         is_crowd=None,
-         edge_links=pose_data.edge_links,
-         edge_colors=pose_data.edge_colors,
-         keypoint_colors=pose_data.keypoint_colors,
-         joint_thickness=2,
-         box_thickness=2,
-         keypoint_radius=5
-     )
-
-     return output_image, skeleton_image

  def greet(name):
      return "Hello " + name + "!!"
 
New side (lines 8-164):
  import numpy as np
  import requests
  from PIL import Image
+
+ import gradio as gr
+ import cv2
+ import tempfile
+ import numpy as np
  import torch
  from torchvision import transforms
  from PIL import Image
+ import matplotlib.pyplot as plt
+ from io import BytesIO

  # Load the YOLO model
+ model_path = "./best.pt"
  model = torch.jit.load(model_path, map_location=torch.device("cpu"))
  model.eval()
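Note: the new model_path points at "./best.pt", but the file is still loaded with torch.jit.load, which expects a TorchScript archive rather than a raw checkpoint, so the load only succeeds if best.pt was saved as TorchScript. A minimal sketch of producing such a file, assuming the weights come from an Ultralytics-style training run (an assumption, not stated in this commit):

    # Hypothetical export step -- not part of this commit.
    from ultralytics import YOLO
    YOLO("best.pt").export(format="torchscript")  # writes a .torchscript file next to the weights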

+ transform = transforms.Compose([
+     transforms.Resize((640, 640)),
+     transforms.ToTensor(),
+ ])
+
+ OBJECT_NAMES = ['enemies']
+
+ def detect_objects_in_image(image):
+     img_tensor = transform(image).unsqueeze(0)
+     orig_w, orig_h = image.size
+
+     with torch.no_grad():
+         pred = model(img_tensor)[0]
+
+     if isinstance(pred[0], torch.Tensor):
+         pred = [p.cpu().numpy() for p in pred]
+
+     pred = np.concatenate(pred, axis=0)
+     conf_thres = 0.25
+     mask = pred[:, 4] > conf_thres
+     pred = pred[mask]
+
+     if len(pred) == 0:
+         return Image.fromarray(np.array(image)), None  # Return only image and None for graph
+
+     boxes, scores, class_probs = pred[:, :4], pred[:, 4], pred[:, 5:]
+     class_ids = np.argmax(class_probs, axis=1)
+
+     boxes[:, 0] = boxes[:, 0] - (boxes[:, 2] / 2)
+     boxes[:, 1] = boxes[:, 1] - (boxes[:, 3] / 2)
+     boxes[:, 2] = boxes[:, 0] + boxes[:, 2]
+     boxes[:, 3] = boxes[:, 1] + boxes[:, 3]
+
+     boxes[:, [0, 2]] *= orig_w / 640
+     boxes[:, [1, 3]] *= orig_h / 640
+     boxes = np.clip(boxes, 0, [orig_w, orig_h, orig_w, orig_h])
+
+     indices = cv2.dnn.NMSBoxes(boxes.tolist(), scores.tolist(), conf_thres, 0.5)
+
+     object_counts = {name: 0 for name in OBJECT_NAMES}
+     img_array = np.array(image)
+
+     if len(indices) > 0:
+         for i in indices.flatten():
+             x1, y1, x2, y2 = map(int, boxes[i])
+             cls = class_ids[i]
+             object_name = OBJECT_NAMES[cls] if cls < len(OBJECT_NAMES) else f"Unknown ({cls})"
+             if object_name in object_counts:
+                 object_counts[object_name] += 1
+             cv2.rectangle(img_array, (x1, y1), (x2, y2), (0, 255, 0), 2)
+             cv2.putText(img_array, f"{object_name}: {scores[i]:.2f}", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
+
+     # Generate and return graph instead of dictionary
+     graph_image = generate_vehicle_count_graph(object_counts)
+
+     return Image.fromarray(img_array), graph_image  # Now returning only 2 outputs
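detect_objects_in_image assumes the TorchScript export returns raw YOLO-style candidates: model(x)[0] laid out as [cx, cy, w, h, confidence, class scores...] at the 640x640 input resolution, which is what the slicing (pred[:, 4], pred[:, 5:]), the xywh-to-xyxy conversion, and the rescaling above rely on. A quick way to confirm that layout for a given export (illustrative only; the blank dummy input is an assumption):

    # Probe the exported model's output shape before trusting the slicing above.
    from PIL import Image
    import torch

    probe = transform(Image.new("RGB", (640, 640))).unsqueeze(0)  # blank dummy frame
    with torch.no_grad():
        out = model(probe)[0]
    print(out.shape if isinstance(out, torch.Tensor) else [o.shape for o in out])
    # expect something like (1, N, 6) for this single-class model: 4 box values + confidence + 1 class score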
+
+
+ # def generate_vehicle_count_graph(object_counts):
+ #     color_palette = ['#4C9ACD', '#88B8A3', '#7F9C9C', '#D1A3B5', '#A1C6EA', '#FFB6C1', '#F0E68C', '#D3B0D8', '#F8A5D1', '#B8B8D1']
+
+ #     fig, ax = plt.subplots(figsize=(8, 5))
+ #     labels = list(object_counts.keys())
+ #     values = list(object_counts.values())
+
+ #     ax.bar(labels, values, color=color_palette[:len(labels)])
+
+ #     ax.set_xlabel("Vehicle Categories", fontsize=12, fontweight='bold')
+ #     ax.set_ylabel("Number of Vehicles", fontsize=12, fontweight='bold')
+ #     ax.set_title("Detected Vehicles in Image", fontsize=14, fontweight='bold')
+
+ #     plt.xticks(rotation=45, ha='right', fontsize=10)
+ #     plt.yticks(fontsize=10)
+
+ #     plt.tight_layout()
+
+ #     buf = BytesIO()
+ #     plt.savefig(buf, format='png')
+ #     buf.seek(0)
+
+ #     return Image.open(buf)
+
+ def generate_vehicle_count_graph(object_counts):
+     color_palette = ['#4C9ACD', '#88B8A3', '#7F9C9C', '#D1A3B5', '#A1C6EA', '#FFB6C1', '#F0E68C', '#D3B0D8', '#F8A5D1', '#B8B8D1']
+
+     fig, ax = plt.subplots(figsize=(8, 5))
+     labels = list(object_counts.keys())
+     values = list(object_counts.values())
+
+     ax.bar(labels, values, color=color_palette[:len(labels)])
+     ax.set_xlabel("Vehicle Categories", fontsize=12, fontweight='bold')
+     ax.set_ylabel("Number of Vehicles", fontsize=12, fontweight='bold')
+     ax.set_title("Detected Vehicles in Image", fontsize=14, fontweight='bold')
+
+     plt.xticks(rotation=45, ha='right', fontsize=10)
+     plt.yticks(fontsize=10)
+     plt.tight_layout()
+
+     buf = BytesIO()
+     plt.savefig(buf, format='png')
+     buf.seek(0)
+
+     plt.close(fig)  # ✅ CLOSE THE FIGURE TO FREE MEMORY
+
+     return Image.open(buf)
+
+ def detect_objects_in_video(video_input):
+     cap = cv2.VideoCapture(video_input)
+     if not cap.isOpened():
+         return "Error: Cannot open video file.", None  # Returning a second value (None) to match expected outputs
+
+     frame_width, frame_height, fps = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)), int(cap.get(cv2.CAP_PROP_FPS))
+     temp_video_output = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
+     out = cv2.VideoWriter(temp_video_output, cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_width, frame_height))
+
+     # Initialize the counts for vehicle categories
+     total_counts = {name: 0 for name in ['car', 'truck', 'bus', 'motorcycle', 'bicycle']}
+
+     while cap.isOpened():
+         ret, frame = cap.read()
+         if not ret:
+             break
+
+         image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+
+         # Get frame with detected objects and graph
+         frame_with_boxes, graph_image = detect_objects_in_image(image)
+
+         # Convert image back to OpenCV format for writing video
+         out.write(cv2.cvtColor(np.array(frame_with_boxes), cv2.COLOR_RGB2BGR))
+
+     cap.release()
+     out.release()
+
+     return temp_video_output, graph_image  # Return both expected outputs
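Two details of detect_objects_in_video worth flagging: graph_image is assigned only inside the frame loop, so a video with no decodable frames raises UnboundLocalError at the final return, and total_counts is initialized but never updated (its vehicle class names also do not appear in OBJECT_NAMES), so the returned graph reflects only the last processed frame. A small defensive wrapper, sketched as an illustration rather than part of the commit:

    # Hypothetical wrapper so the Gradio outputs always receive two values.
    def detect_objects_in_video_safe(video_input):
        try:
            return detect_objects_in_video(video_input)
        except UnboundLocalError:
            # no frame was decoded, so graph_image was never assigned inside the loop
            return "Error: no readable frames in the video.", None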

  def greet(name):
      return "Hello " + name + "!!"
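
The Gradio wiring that exposes these functions lies outside this hunk; a minimal sketch of how they could be hooked up, with all component names and layout being illustrative assumptions:

    # Illustrative UI wiring only -- the actual interface code lives elsewhere in app.py.
    import gradio as gr

    with gr.Blocks() as demo:
        with gr.Tab("Image"):
            img_in = gr.Image(type="pil")
            img_out = gr.Image(type="pil")
            img_graph = gr.Image(type="pil")
            gr.Button("Detect").click(detect_objects_in_image, inputs=img_in, outputs=[img_out, img_graph])
        with gr.Tab("Video"):
            vid_in = gr.Video()
            vid_out = gr.Video()
            vid_graph = gr.Image(type="pil")
            gr.Button("Detect").click(detect_objects_in_video, inputs=vid_in, outputs=[vid_out, vid_graph])

    demo.launch()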