supersolar committed
Commit c6e250d · verified · 1 Parent(s): f5bc962

Update 1.py

Files changed (1):
1. 1.py +117 -148
1.py CHANGED
@@ -1,3 +1,4 @@
+ #%cd /workspace/florence-sam
  import os
  from typing import Tuple, Optional
  import shutil
@@ -9,6 +10,11 @@ import supervision as sv
  import torch
  from PIL import Image
  from tqdm import tqdm
+ import sys
+ import json
+ import pickle
+ os.chdir("/workspace/florence-sam")
+ sys.path.append('/workspace/florence-sam')
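+ # os.chdir plus the sys.path entry make the repo-local utils package
+ # importable no matter where the interpreter was launched.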
  from utils.video import generate_unique_name, create_directory, delete_directory
  from utils.florence import load_florence_model, run_florence_inference, \
      FLORENCE_DETAILED_CAPTION_TASK, \
@@ -19,161 +25,124 @@ from utils.sam import load_sam_image_model, run_sam_inference, load_sam_video_model
  DEVICE = torch.device("cuda")
  DEVICE = [torch.device(f'cuda:{i}') for i in range(torch.cuda.device_count())][-1]
  DEVICE = [torch.device(f'cuda:{i}') for i in range(torch.cuda.device_count())][0]
- # DEVICE = torch.device("cpu")
 
  torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
  if torch.cuda.get_device_properties(0).major >= 8:
      torch.backends.cuda.matmul.allow_tf32 = True
      torch.backends.cudnn.allow_tf32 = True
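  # Note: the three DEVICE assignments overwrite one another, leaving DEVICE
  # pinned to cuda:0; TF32 is enabled only on Ampere (compute capability >= 8).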
 
  FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)
  SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)
- SAM_VIDEO_MODEL = load_sam_video_model(device=DEVICE)
- 
- texts = ['the table', 'all person', 'ball']
- from PIL import Image
- import supervision as sv
- 
- def detect_objects_in_image(image_input_path, texts):
-     # Load the image
-     image_input = Image.open(image_input_path)
- 
-     # Initialize the list of detections
-     detections_list = []
- 
-     # Run detection for each text prompt
-     for text in texts:
-         _, result = run_florence_inference(
-             model=FLORENCE_MODEL,
-             processor=FLORENCE_PROCESSOR,
-             device=DEVICE,
-             image=image_input,
-             task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
-             text=text
-         )
- 
-         # Build a supervision Detections object from the result
-         detections = sv.Detections.from_lmm(
-             lmm=sv.LMM.FLORENCE_2,
-             result=result,
-             resolution_wh=image_input.size
-         )
- 
-         # Run SAM inference on the detected boxes
-         detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
- 
-         # Add the detections to the list
-         detections_list.append(detections)
- 
-     # Merge all detection results
-     detections = sv.Detections.merge(detections_list)
- 
-     # Run SAM inference again on the merged detections
-     detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
- 
-     return detections
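- # Note: run_sam_inference is applied once per prompt and once more on the
- # merged set, so the final masks are recomputed over all prompt detections.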
- # @title # Merge mask and blur: merge_image_with_mask
- import numpy as np
- import cv2
- import os
- from PIL import Image, ImageFilter
- 
- # Recreate the mask folder from scratch (create if missing, wipe, recreate)
- mask_folder = 'mask1'
- if not os.path.exists(mask_folder):
-     os.makedirs(mask_folder)
- shutil.rmtree('mask1')
- mask_folder = 'mask1'
- if not os.path.exists(mask_folder):
-     os.makedirs(mask_folder)
- 
- def merge_image_with_mask(image_input_path, detections, output_folder):
-     # Create the output folder
-     if not os.path.exists(output_folder):
-         os.makedirs(output_folder)
- 
-     # Extract the image file name
-     image_name = os.path.basename(image_input_path)
-     output_path = os.path.join(output_folder, image_name)
- 
-     # Mask folder
-     mask_folder = 'mask1'
- 
-     # Merge the masks
-     combined_mask = np.zeros_like(detections.mask[0], dtype=np.uint8)
-     for mask in detections.mask:
-         combined_mask += mask
-     combined_mask = np.clip(combined_mask, 0, 255)
-     combined_mask = combined_mask.astype(np.uint8)
- 
-     # Dilate the mask
-     kernel = np.ones((6, 6), np.uint8)
-     dilated_mask = cv2.dilate(combined_mask, kernel, iterations=1)
- 
-     # Save the dilated mask
-     mask_path = os.path.join(mask_folder, image_name)
-     cv2.imwrite(mask_path, dilated_mask * 255)
- 
-     # Read the original image
-     original_image = cv2.imread(image_input_path)
- 
-     # Read the mask image
-     #mask_image = cv2.imread(mask_path)
- 
-     # Make sure the original image and the mask have the same dimensions
-     #assert original_image.shape == mask_image.shape, "The images must have the same dimensions."
- 
-     # Use the mask to extract the object region from the original image
-     masked_image = cv2.bitwise_and(original_image, original_image, mask=dilated_mask)
-     # Blur the whole original image
-     #blurred_image = cv2.GaussianBlur(original_image, (21, 21), 500)  # blur with a large kernel
-     blurred_image = cv2.medianBlur(original_image, 21)
-     # Keep only the area outside the mask in the blurred image
-     blurred_image = cv2.bitwise_and(blurred_image, blurred_image, mask=~dilated_mask)
-     # Composite the extracted region onto the blurred image
-     result = np.where(dilated_mask[:, :, None] > 0, masked_image, blurred_image)
- 
-     # Save the merged image
-     cv2.imwrite(output_path, result)
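- # The np.where composites the sharp masked region over the median-blurred
- # background, so detected objects stay in focus while the rest is blurred.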
- 
- # @title # Batch-process a folder with a progress bar: process_images_in_folder(input_folder)
- from tqdm import tqdm
- import shutil
- def process_images_in_folder(input_folder):
-     # Recreate the output folder from scratch
-     output_folder = 'okframe1'
-     if not os.path.exists(output_folder):
-         os.makedirs(output_folder)
-     shutil.rmtree('okframe1')
-     output_folder = 'okframe1'
-     if not os.path.exists(output_folder):
-         os.makedirs(output_folder)
- 
-     # Collect all image files in the folder
-     files = [f for f in os.listdir(input_folder) if f.endswith('.jpg') or f.endswith('.png') or f.endswith('.jpeg')]
- 
-     # Use tqdm to display a progress bar
-     for filename in tqdm(files, desc="Gpu 1 Processing Images"):
-         image_input_path = os.path.join(input_folder, filename)
- 
-         # Detect objects
-         detections = detect_objects_in_image(
-             image_input_path=image_input_path,
-             texts=texts
-         )
- 
-         # Merge the image with its mask
-         merge_image_with_mask(
-             image_input_path=image_input_path,
-             detections=detections,
-             output_folder=output_folder
-         )
- 
- # Example usage
- input_folder = 'frame1'
- process_images_in_folder(input_folder)
+ 
+ with open('/workspace/florence-sam/texts.pkl', 'rb') as file:
+     texts = pickle.load(file)
+ print(texts)
+ 
+ with open('/workspace/florence-sam/output_video.pkl', 'rb') as file:
+     output_video = pickle.load(file)
+ print(output_video)
+ 
+ VIDEO_SCALE_FACTOR = 1
+ VIDEO_TARGET_DIRECTORY = "/workspace/florence-sam"
+ create_directory(directory_path=VIDEO_TARGET_DIRECTORY)
+ 
+ video_input = output_video
+ texts = ['the table', 'men', 'ball']
+ #VIDEO_TARGET_DIRECTORY = "/workspace/florence-sam"
+ if not video_input:
+     print("Please upload a video.")
+ '''
+ frame_generator = sv.get_video_frames_generator(video_input)
+ frame = next(frame_generator)
+ frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+ '''
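+ # Note: the hard-coded prompt list above overrides the texts just loaded
+ # from texts.pkl.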
+ 
+ frame_generator = sv.get_video_frames_generator(video_input)
+ 
+ # Get the total number of frames in the video
+ total_frames = int(sv.VideoInfo.from_video_path(video_input).total_frames)
+ 
+ # Compute the position of the middle frame
+ middle_frame_index = total_frames // 2
+ with open('/workspace/florence-sam/middle_frame_index.pkl', 'wb') as file:
+     pickle.dump(middle_frame_index, file)
+ # Read up to the middle frame
+ for _ in range(middle_frame_index):
+     frame = next(frame_generator)
+ 
+ # Convert the frame from BGR to RGB and store it in a PIL image object
+ frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
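+ # The middle frame serves as the reference frame for prompting; its index is
+ # pickled so a later stage can seek back to the same frame.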
77
+
78
+ detections_list = []
79
+ width, height = frame.size
80
+ all_ok_bboxes = []
81
+ half_area = width * height * 0.5
82
+
83
+ # 存储所有 the table 的边界框和面积
84
+ table_bboxes = []
85
+ table_areas = []
86
+ given_area =1000
87
+ ok_result =[]
88
+ for text in texts:
89
+ _, result = run_florence_inference(
90
+ model=FLORENCE_MODEL,
91
+ processor=FLORENCE_PROCESSOR,
92
+ device=DEVICE,
93
+ image=frame,
94
+ task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
95
+ text=text )
+     #print(result)
+     for bbox, label in zip(result['<OPEN_VOCABULARY_DETECTION>']['bboxes'], result['<OPEN_VOCABULARY_DETECTION>']['bboxes_labels']):
+         print(bbox, label)
+         new_result = {'<OPEN_VOCABULARY_DETECTION>': {'bboxes': [bbox], 'bboxes_labels': [label], 'polygons': [], 'polygons_labels': []}}
+         print(new_result)
+         if label == 'ping pong ball':
+             # Compute the area of the current ping pong ball
+             area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
+             # Keep it only if the area does not exceed the given bound
+             if area <= given_area:
+                 all_ok_bboxes.append([[bbox[0], bbox[1]], [bbox[2], bbox[3]]])
+                 ok_result.append(new_result)
+         elif label == 'the table':
+             # Compute the area of the current table
+             print('the table!!!!')
+             area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
+             table_bboxes.append([[bbox[0] - 100, bbox[1]], [bbox[2] + 100, bbox[3]]])
+             table_areas.append(area)
+             # Keep the matching result so the selected table can append its own entry
+             table_results.append(new_result)
+         elif label == 'table tennis bat':
+             all_ok_bboxes.append([[bbox[0], bbox[1]], [bbox[2], bbox[3]]])
+             ok_result.append(new_result)
+         elif label == 'men':
+             print('men!!!!')
+             all_ok_bboxes.append([[bbox[0], bbox[1]], [bbox[2], bbox[3]]])
+             ok_result.append(new_result)
+ 
+ # Find the 'the table' detection with the largest area
+ if table_areas:
+     max_area_index = table_areas.index(max(table_areas))
+     max_area_bbox = table_bboxes[max_area_index]
+ 
+     # Keep it only if it covers less than 50% of the frame
+     if max(table_areas) < half_area:
+         all_ok_bboxes.append(max_area_bbox)
+         ok_result.append(table_results[max_area_index])
+ 
+ print(ok_result)
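+ # Filtering rules, in summary: keep small 'ping pong ball' boxes (area <=
+ # given_area), every 'men' and 'table tennis bat' box, and the largest
+ # 'the table' box (widened by 100 px per side) if it covers under half the frame.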
+ with open('/workspace/florence-sam/all_ok_bboxes.pkl', 'wb') as file:
+     pickle.dump(all_ok_bboxes, file)
+ 
+ for xyxy in ok_result:
+     print(frame.size, xyxy)
+     detections = sv.Detections.from_lmm(
+         lmm=sv.LMM.FLORENCE_2,
+         result=xyxy,
+         resolution_wh=frame.size
+     )
+     detections = run_sam_inference(SAM_IMAGE_MODEL, frame, detections)
+     print(detections)
+     detections_list.append(detections)
+ with open('/workspace/florence-sam/detections_list.pkl', 'wb') as file:
+     pickle.dump(detections_list, file)
+ print(detections_list)
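
The pickled outputs above (middle_frame_index.pkl, all_ok_bboxes.pkl, detections_list.pkl) are evidently handed off to a follow-up stage. A minimal sketch of such a consumer, assuming only the file names used in this script (everything else is hypothetical):

import pickle

# Load the reference-frame index and the per-object SAM detections written by 1.py.
with open('/workspace/florence-sam/middle_frame_index.pkl', 'rb') as f:
    middle_frame_index = pickle.load(f)
with open('/workspace/florence-sam/detections_list.pkl', 'rb') as f:
    detections_list = pickle.load(f)  # list of sv.Detections, one per kept box

print(f"reference frame: {middle_frame_index}, objects: {len(detections_list)}")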