SamDaLamb committed
Commit 09624f9 · verified · 1 Parent(s): 8df620d

Update app.py

Files changed (1)
  1. app.py +38 -34
app.py CHANGED
@@ -22,8 +22,6 @@ from io import BytesIO
 # Load the YOLO model
 from models.common import DetectMultiBackend
 
-
-
 weights_path = "./last.pt"
 device = torch.device("cpu") # Correctly define the device
 model = DetectMultiBackend(weights_path, device=device) # Load YOLOv5 model correctly
@@ -49,24 +47,23 @@ transform = transforms.Compose([
 
 OBJECT_NAMES = ['enemies']
 
-def detect_objects_in_image(image):
 
+def detect_objects_in_image(image):
     # Ensure image is a PIL Image
+    data = Image.fromarray(array)
     if isinstance(image, torch.Tensor):
         image = transforms.ToPILImage()(image) # Convert tensor to PIL image
-
-    if isinstance(image, Image.Image):
+
+    if isinstance(data, Image.Image):
         orig_w, orig_h = image.size # PIL image size returns (width, height)
     else:
-        # raise TypeError(f"Expected a PIL Image but got ")
+        raise TypeError(f"Expected a PIL Image but got ")
         print(type(image))
         exit()
 
     # Apply transformation
     img_tensor = transform(image).unsqueeze(0)
 
-
-
     with torch.no_grad():
         pred = model(img_tensor)[0]
 
@@ -106,7 +103,8 @@ def detect_objects_in_image(image):
         if object_name in object_counts:
             object_counts[object_name] += 1
             cv2.rectangle(img_array, (x1, y1), (x2, y2), (0, 255, 0), 2)
-            cv2.putText(img_array, f"{object_name}: {scores[i]:.2f}", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
+            cv2.putText(img_array, f"{object_name}: {scores[i]:.2f}", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
+                        (0, 255, 0), 2)
 
     # Generate and return graph instead of dictionary
     graph_image = generate_vehicle_count_graph(object_counts)
@@ -116,58 +114,61 @@ def detect_objects_in_image(image):
 
 # def generate_vehicle_count_graph(object_counts):
 # color_palette = ['#4C9ACD', '#88B8A3', '#7F9C9C', '#D1A3B5', '#A1C6EA', '#FFB6C1', '#F0E68C', '#D3B0D8', '#F8A5D1', '#B8B8D1']
-
+
 # fig, ax = plt.subplots(figsize=(8, 5))
 # labels = list(object_counts.keys())
 # values = list(object_counts.values())
-
+
 # ax.bar(labels, values, color=color_palette[:len(labels)])
-
+
 # ax.set_xlabel("Vehicle Categories", fontsize=12, fontweight='bold')
 # ax.set_ylabel("Number of Vehicles", fontsize=12, fontweight='bold')
 # ax.set_title("Detected Vehicles in Image", fontsize=14, fontweight='bold')
-
+
 # plt.xticks(rotation=45, ha='right', fontsize=10)
 # plt.yticks(fontsize=10)
-
+
 # plt.tight_layout()
 
 # buf = BytesIO()
 # plt.savefig(buf, format='png')
 # buf.seek(0)
-
+
 # return Image.open(buf)
 
 def generate_vehicle_count_graph(object_counts):
-    color_palette = ['#4C9ACD', '#88B8A3', '#7F9C9C', '#D1A3B5', '#A1C6EA', '#FFB6C1', '#F0E68C', '#D3B0D8', '#F8A5D1', '#B8B8D1']
-
+    color_palette = ['#4C9ACD', '#88B8A3', '#7F9C9C', '#D1A3B5', '#A1C6EA', '#FFB6C1', '#F0E68C', '#D3B0D8', '#F8A5D1',
+                     '#B8B8D1']
+
     fig, ax = plt.subplots(figsize=(8, 5))
     labels = list(object_counts.keys())
    values = list(object_counts.values())
-
+
    ax.bar(labels, values, color=color_palette[:len(labels)])
    ax.set_xlabel("Vehicle Categories", fontsize=12, fontweight='bold')
    ax.set_ylabel("Number of Vehicles", fontsize=12, fontweight='bold')
    ax.set_title("Detected Vehicles in Image", fontsize=14, fontweight='bold')
-
+
    plt.xticks(rotation=45, ha='right', fontsize=10)
    plt.yticks(fontsize=10)
    plt.tight_layout()
-
+
    buf = BytesIO()
    plt.savefig(buf, format='png')
    buf.seek(0)
-
+
    plt.close(fig) # ✅ CLOSE THE FIGURE TO FREE MEMORY
-
+
    return Image.open(buf)
-
+
+
 def detect_objects_in_video(video_input):
    cap = cv2.VideoCapture(video_input)
    if not cap.isOpened():
        return "Error: Cannot open video file.", None # Returning a second value (None) to match expected outputs
 
-    frame_width, frame_height, fps = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)), int(cap.get(cv2.CAP_PROP_FPS))
+    frame_width, frame_height, fps = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(
+        cap.get(cv2.CAP_PROP_FRAME_HEIGHT)), int(cap.get(cv2.CAP_PROP_FPS))
    temp_video_output = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
    out = cv2.VideoWriter(temp_video_output, cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_width, frame_height))
 
@@ -192,34 +193,37 @@ def detect_objects_in_video(video_input):
 
     return temp_video_output, graph_image # Return both expected outputs
 
+
 def greet(name):
     return "Hello " + name + "!!"
 
+
 demo = gr.Interface(fn=greet, inputs="text", outputs="text")
 
 from urllib.request import urlretrieve
 
 # get image examples from github
-urlretrieve("https://github.com/SamDaaLamb/ValorantTracker/blob/main/clip2_-1450-_jpg.jpg?raw=true", "clip2_-1450-_jpg.jpg") # make sure to use "copy image address when copying image from Github"
-urlretrieve("https://github.com/SamDaaLamb/ValorantTracker/blob/main/clip2_-539-_jpg.jpg?raw=true", "clip2_-539-_jpg.jpg")
-examples = [ # need to manually delete cache everytime new examples are added
-    ["clip2_-1450-_jpg.jpg"],
+urlretrieve("https://github.com/SamDaaLamb/ValorantTracker/blob/main/clip2_-1450-_jpg.jpg?raw=true",
+            "clip2_-1450-_jpg.jpg") # make sure to use "copy image address when copying image from Github"
+urlretrieve("https://github.com/SamDaaLamb/ValorantTracker/blob/main/clip2_-539-_jpg.jpg?raw=true",
+            "clip2_-539-_jpg.jpg")
+examples = [ # need to manually delete cache everytime new examples are added
+    ["clip2_-1450-_jpg.jpg"],
     ["clip2_-539-_jpg.jpg"]]
 
-
 # define app features and run
 title = "SpecLab Demo"
 description = "<p style='text-align: center'>Gradio demo for an ASPP model architecture trained on the SpecLab dataset. To use it, simply add your image, or click one of the examples to load them. Since this demo is run on CPU only, please allow additional time for processing. </p>"
 article = "<p style='text-align: center'><a href='https://github.com/Nano1337/SpecLab'>Github Repo</a></p>"
 css = "#0 {object-fit: contain;} #1 {object-fit: contain;}"
-demo = gr.Interface(fn=detect_objects_in_image,
-                    title=title,
+demo = gr.Interface(fn=detect_objects_in_image,
+                    title=title,
                     description=description,
                     article=article,
-                    inputs=gr.Image(elem_id=0, show_label=False),
+                    inputs=gr.Image(elem_id=0, show_label=False),
                     outputs=gr.Image(elem_id=1, show_label=False),
-                    css=css,
-                    examples=examples,
+                    css=css,
+                    examples=examples,
                     cache_examples=True,
                     allow_flagging='never')
 demo.launch()
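
One thing worth flagging in the new detect_objects_in_image: the added line data = Image.fromarray(array) references a name array that does not appear anywhere in this diff, and the new raise TypeError(...) makes the print(type(image)) and exit() lines after it unreachable. Below is a minimal sketch of how that input-normalization step could be written instead, assuming the function should accept the NumPy array that gr.Image passes by default, plus tensors and PIL images; the helper name _to_pil and the numpy import are illustrative and not part of this commit.

import numpy as np
import torch
from PIL import Image
from torchvision import transforms

def _to_pil(image):
    # Normalize the supported input types to a PIL Image (hypothetical helper, not in this commit).
    if isinstance(image, np.ndarray):       # what gr.Image passes to the callback by default
        return Image.fromarray(image)
    if isinstance(image, torch.Tensor):     # CHW tensor -> PIL image
        return transforms.ToPILImage()(image)
    if isinstance(image, Image.Image):
        return image
    raise TypeError(f"Expected a PIL Image, numpy array or tensor, got {type(image)}")

Calling image = _to_pil(image) at the top of detect_objects_in_image would then make image.size and the transform(image) call safe for any of those input types.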