YANGYYYY committed on
Commit
459da92
·
verified ·
1 Parent(s): 8afe88f

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +38 -11
inference.py CHANGED
@@ -247,33 +247,60 @@ class Predictor:
247
  # if is_gg_drive:
248
  # temp_file = f'tmp_anime.{output_path.split(".")[-1]}'
249
 
250
- def transform_and_save(self, frames, count):
251
- transformed_frames = []
 
 
 
 
 
 
252
  anime_images = self.transform(frames)
253
  for i in range(count):
254
  img = np.clip(anime_images[i], 0, 255).astype(np.uint8)
255
- transformed_frames.append(img)
256
- return transformed_frames
 
257
 
258
- frame_count = len(video_frames)
259
- transformed_video_frames = []
 
 
 
260
 
261
- batch_shape = (batch_size) + video_frames[0].shape
 
 
 
 
 
 
 
 
 
 
 
 
 
262
  frames = np.zeros(batch_shape, dtype=np.uint8)
263
  frame_idx = 0
264
 
265
  try:
266
- for frame in video_frames:
 
 
 
267
  frames[frame_idx] = frame
268
  frame_idx += 1
269
  if frame_idx == batch_size:
270
- transformed_frames = transform_and_save(frames, frame_idx)
271
- transformed_video_frames.extend(transformed_frames)
272
  frame_idx = 0
273
  except Exception as e:
274
  print(e)
 
 
275
 
276
- return transformed_video_frames
277
 
278
 
279
 
 
247
  # if is_gg_drive:
248
  # temp_file = f'tmp_anime.{output_path.split(".")[-1]}'
249
 
250
+ # def transform_and_save(self, frames, count):
251
+ # transformed_frames = []
252
+ # anime_images = self.transform(frames)
253
+ # for i in range(count):
254
+ # img = np.clip(anime_images[i], 0, 255).astype(np.uint8)
255
+ # transformed_frames.append(img)
256
+ # return transformed_frames
257
def transform_and_write(frames, count, video_buffer):
    """Run the anime-style transform on a batch of frames and append the
    JPEG-encoded results to *video_buffer*.

    frames       -- batch array of frames; only the first *count* entries are
                    valid (the tail of the batch may be zero padding).
    count        -- number of valid frames in this batch.
    video_buffer -- list collecting the encoded frames as raw JPEG bytes.
    """
    # NOTE: `self` is captured from the enclosing method's scope (closure).
    anime_images = self.transform(frames)
    for i in range(count):
        # The transform's output may fall outside [0, 255]; clamp before
        # casting so the uint8 conversion doesn't wrap around.
        img = np.clip(anime_images[i], 0, 255).astype(np.uint8)
        success, encoded_image = cv2.imencode('.jpg', img)
        if success:
            video_buffer.append(encoded_image.tobytes())
        else:
            # BUG FIX: the original silently dropped frames that failed to
            # encode; report the failure but keep processing the rest.
            print(f'Failed to JPEG-encode frame {i}')
264
 
265
+ video_capture = cv2.VideoCapture(video_frames)
266
+ frame_width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
267
+ frame_height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
268
+ fps = int(video_capture.get(cv2.CAP_PROP_FPS))
269
+ frame_count = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
270
 
271
+ if start or end:
272
+ start_frame = int(start * fps)
273
+ end_frame = int(end * fps) if end else frame_count
274
+ video_capture.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
275
+ frame_count = end_frame - start_frame
276
+
277
+ # frame_count = len(video_frames)
278
+ # transformed_video_frames = []
279
+ video_buffer = []
280
+
281
+ # batch_shape = (batch_size) + video_frames[0].shape
282
+ # frames = np.zeros(batch_shape, dtype=np.uint8)
283
+ # frame_idx = 0
284
+ _shape = (batch_size, frame_height, frame_width, 3)
285
  frames = np.zeros(batch_shape, dtype=np.uint8)
286
  frame_idx = 0
287
 
288
  try:
289
+ for _ in range(frame_count):
290
+ ret, frame = video_capture.read()
291
+ if not ret:
292
+ break
293
  frames[frame_idx] = frame
294
  frame_idx += 1
295
  if frame_idx == batch_size:
296
+ transform_and_write(frames, frame_idx, video_buffer)
 
297
  frame_idx = 0
298
  except Exception as e:
299
  print(e)
300
+ finally:
301
+ video_capture.release()
302
 
303
+ return video_buffer
304
 
305
 
306