File size: 4,098 Bytes
fa0bd64 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import glob
import io
import os
from collections import defaultdict
from typing import Any, Dict, List, Optional, Union

import cv2
import numpy as np
import PIL
import PIL.Image
import requests
from transformers import PretrainedConfig
# from llava.constants import MEDIA_TOKENS
# from llava.media import Image, Video
# from llava.utils import make_list
# from llava.utils.logging import logger
# Placeholder strings written into message text at the position of each media
# part; `extract_media` also strips any of them that appear in raw user text.
MEDIA_TOKENS = dict(
    image="<image>",
    video="<vila/video>",
)
class Media:
    """Common base class for the non-text parts of a prompt message."""
class File(Media):
    """Media part identified by a location string (local path or http(s) URL).

    Args:
        path: Where the media lives. ``os.PathLike`` objects such as
            ``pathlib.Path`` are converted with ``os.fspath`` so that string
            operations used by the loaders (e.g. ``startswith``) work on
            ``self.path``; plain ``str`` values are stored unchanged.
    """

    def __init__(self, path: Union[str, os.PathLike]) -> None:
        self.path = os.fspath(path)

    def __repr__(self) -> str:
        return f"{type(self).__name__}(path={self.path!r})"
class Image(File):
    """Image prompt part whose ``path`` is a local file or an http(s) URL."""
class Video(File):
    """Video prompt part whose ``path`` is a video file or a directory of frame images."""
def make_list(obj: Any) -> List:
    """Return ``obj`` itself if it is a list, otherwise a one-element list holding it."""
    if isinstance(obj, list):
        return obj
    return [obj]
def _extract_image(
    image: Union[Image, PIL.Image.Image], *, timeout: Optional[float] = 60.0
) -> PIL.Image.Image:
    """Return ``image`` as a ``PIL.Image.Image``, opening it first when needed.

    Args:
        image: Either an already-decoded PIL image (returned unchanged) or an
            ``Image`` whose ``path`` is a local file path or an ``http://`` /
            ``https://`` URL.
        timeout: Seconds passed to ``requests.get`` for URL downloads;
            ``None`` waits indefinitely.

    Returns:
        The opened PIL image. ``PIL.Image.open`` is lazy, so pixel data is
        decoded on first use.

    Raises:
        requests.HTTPError: If downloading a URL returns an error status.
    """
    if not isinstance(image, Image):
        return image

    if image.path.startswith(("http://", "https://")):
        # Buffer the body in memory so the HTTP connection is closed here and
        # PIL gets a seekable stream; `.content` also undoes any
        # Content-Encoding, which reading `.raw` directly would not.
        with requests.get(image.path, timeout=timeout) as response:
            response.raise_for_status()
            return PIL.Image.open(io.BytesIO(response.content))

    return PIL.Image.open(image.path)
def _load_video(video_path: str, *, num_frames: int) -> List[PIL.Image.Image]:
# Load video frames from a directory
if os.path.isdir(video_path):
frame_paths = sorted(glob.glob(os.path.join(video_path, "*")))
indices = np.round(np.linspace(0, len(frame_paths) - 1, num_frames)).astype(int)
return [PIL.Image.open(frame_paths[index]) for index in indices]
# Load video frames from a video file
vidcap = cv2.VideoCapture(video_path)
# Find the last frame as frame count might not be accurate
frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
while frame_count > 0:
vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_count - 1)
if vidcap.grab():
break
frame_count -= 1
else:
raise ValueError(f"Video '{video_path}' has no frames.")
# Extract frames uniformly
indices = np.round(np.linspace(0, frame_count - 1, num_frames)).astype(int)
frames = {}
for index in indices:
if index in frames:
continue
vidcap.set(cv2.CAP_PROP_POS_FRAMES, index)
success, frame = vidcap.read()
if not success:
print(f"Failed to read frame {index} from video '{video_path}'. Skipped.")
continue
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frames[index] = PIL.Image.fromarray(frame)
return [frames[index] for index in indices if index in frames]
def _extract_video(video: Video, config: PretrainedConfig) -> List[PIL.Image.Image]:
num_frames = config.num_video_frames
if getattr(config, "fps") != 0:
print("Extracting frames from video with specified FPS is not supported yet. Ignored.")
frames = _load_video(video.path, num_frames=num_frames)
return frames
def extract_media(
    messages: List[Dict[str, Any]],
    config: Optional[PretrainedConfig] = None,
    draft: bool = False,
) -> Dict[str, List[Any]]:
    """Pull media parts out of chat messages and flatten each message to text.

    Each ``message["value"]`` (a single part or a list of parts) is rewritten
    IN PLACE to a string:

    * ``str`` parts are kept, after any ``MEDIA_TOKENS`` value embedded in them
      is removed (with a printed warning) and the result stripped;
    * ``Image`` / ``PIL.Image.Image`` parts become the image token, and the part
      (raw when ``draft``, otherwise via ``_extract_image``) is appended to
      ``media["image"]``;
    * ``Video`` parts become the video token, and the part (raw when ``draft``,
      otherwise the frames from ``_extract_video(part, config)``) is appended to
      ``media["video"]``.

    Returns:
        A ``defaultdict(list)`` mapping ``"image"`` / ``"video"`` to the
        collected media, in message order.

    Raises:
        ValueError: If a part has any other type.
    """
    collected: Dict[str, List[Any]] = defaultdict(list)
    for msg in messages:
        pieces: List[str] = []
        for chunk in make_list(msg["value"]):
            if isinstance(chunk, str):
                for media_token in MEDIA_TOKENS.values():
                    if media_token not in chunk:
                        continue
                    print(f"Media token '{media_token}' found in text: '{chunk}'. Removed.")
                    chunk = chunk.replace(media_token, "").strip()
                pieces.append(chunk)
            elif isinstance(chunk, (Image, PIL.Image.Image)):
                collected["image"].append(chunk if draft else _extract_image(chunk))
                pieces.append(MEDIA_TOKENS["image"])
            elif isinstance(chunk, Video):
                collected["video"].append(chunk if draft else _extract_video(chunk, config))
                pieces.append(MEDIA_TOKENS["video"])
            else:
                raise ValueError(f"Unsupported prompt part type: {type(chunk)}")
        msg["value"] = "".join(pieces)
    return collected
|