# Copyright (2025) [Seed-VL-Cookbook] Bytedance Seed
import cv2
import json
import time
import math
import base64
import requests
import torch
import decord
import numpy as np
from typing import Optional
from PIL import Image, ImageSequence
from torchvision.io import read_image, encode_jpeg, ImageReadMode
from torchvision.transforms.functional import resize, pil_to_tensor
from torchvision.transforms import InterpolationMode

class ConversationModeI18N:
    G = "General"
    D = "Deep Thinking"


class ConversationModeCN:
    G = "常规"
    D = "深度思考"

def round_by_factor(number: int, factor: int) -> int:
    """Returns the closest integer to 'number' that is divisible by 'factor'."""
    return round(number / factor) * factor


def ceil_by_factor(number: int, factor: int) -> int:
    """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
    return math.ceil(number / factor) * factor


def floor_by_factor(number: int, factor: int) -> int:
    """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
    return math.floor(number / factor) * factor
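
# Worked example (values follow directly from the definitions above), with factor=28:
#   round_by_factor(250, 28) == 252   # 250 / 28 ≈ 8.93 -> 9 * 28
#   ceil_by_factor(250, 28)  == 252   # ceil(8.93)      -> 9 * 28
#   floor_by_factor(250, 28) == 224   # floor(8.93)     -> 8 * 28
# 28 is the grid factor used throughout this file: every pixel budget below is
# a multiple of 28 * 28.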

def get_resized_hw_for_Navit(
    height: int,
    width: int,
    min_pixels: int,
    max_pixels: int,
    max_ratio: int = 200,
    factor: int = 28,
):
    if max(height, width) / min(height, width) > max_ratio:
        raise ValueError(
            f"absolute aspect ratio must be smaller than {max_ratio}, got {max(height, width) / min(height, width)}"
        )
    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    if h_bar * w_bar > max_pixels:
        # Too many pixels: shrink both sides by the same ratio, then snap to
        # the grid from below so the budget is never exceeded.
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = floor_by_factor(height / beta, factor)
        w_bar = floor_by_factor(width / beta, factor)
    elif h_bar * w_bar < min_pixels:
        # Too few pixels: enlarge both sides symmetrically and snap to the
        # grid from above so the minimum is always met.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)
    return int(h_bar), int(w_bar)
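
# Example (hypothetical input): a 1080x1920 frame with the class defaults below,
#   get_resized_hw_for_Navit(1080, 1920, min_pixels=4 * 28 * 28, max_pixels=5120 * 28 * 28)
# rounds both sides to the nearest multiple of 28 and returns (1092, 1932);
# 1092 * 1932 = 2,109,744 pixels already sits inside [3136, 4,014,080], so
# neither the shrink nor the enlarge branch fires.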

class SeedVLInfer:

    def __init__(
        self,
        model_id: str,
        api_key: str,
        base_url: str = 'https://ark.cn-beijing.volces.com/api/v3/chat/completions',
        min_pixels: int = 4 * 28 * 28,
        max_pixels: int = 5120 * 28 * 28,
        video_sampling_strategy: dict = {
            'sampling_fps': 1,
            'min_n_frames': 16,
            'max_video_length': 81920,
            'max_pixels_choices': [
                640 * 28 * 28, 512 * 28 * 28, 384 * 28 * 28, 256 * 28 * 28,
                160 * 28 * 28, 128 * 28 * 28
            ],
            'use_timestamp': True,
        },
    ):
        self.base_url = base_url
        self.api_key = api_key
        self.model_id = model_id
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.sampling_fps = video_sampling_strategy.get('sampling_fps', 1)
        self.min_n_frames = video_sampling_strategy.get('min_n_frames', 16)
        self.max_video_length = video_sampling_strategy.get(
            'max_video_length', 81920)
        self.max_pixels_choices = video_sampling_strategy.get(
            'max_pixels_choices', [
                640 * 28 * 28, 512 * 28 * 28, 384 * 28 * 28, 256 * 28 * 28,
                160 * 28 * 28, 128 * 28 * 28
            ])
        self.use_timestamp = video_sampling_strategy.get('use_timestamp', True)

    def preprocess_video(self, video_path: str):
        try:
            video_reader = decord.VideoReader(video_path, num_threads=2)
            fps = video_reader.get_avg_fps()
        except decord._ffi.base.DECORDError:
            # Fall back to PIL for inputs decord cannot open (e.g. animated images).
            video_reader = [
                frame.convert('RGB')
                for frame in ImageSequence.Iterator(Image.open(video_path))
            ]
            fps = 1
        length = len(video_reader)
        # Sample at self.sampling_fps, but never fewer than min_n_frames and
        # never more frames than the video actually has.
        n_frames = min(
            max(math.ceil(length / fps * self.sampling_fps),
                self.min_n_frames), length)
        frame_indices = np.linspace(0, length - 1,
                                    n_frames).round().astype(int).tolist()
        # Fallback budget in case max_pixels_choices is empty.
        max_pixels = self.max_pixels
        # Walk down the per-frame pixel budgets until the total token count
        # (frames * pixels / 28 / 28) fits; on the last round, drop frames instead.
        for round_idx, max_pixels in enumerate(self.max_pixels_choices):
            is_last_round = round_idx == len(self.max_pixels_choices) - 1
            if len(frame_indices
                   ) * max_pixels / 28 / 28 > self.max_video_length:
                if is_last_round:
                    max_frame_num = int(self.max_video_length / max_pixels *
                                        28 * 28)
                    select_ids = np.linspace(
                        0,
                        len(frame_indices) - 1,
                        max_frame_num).round().astype(int).tolist()
                    frame_indices = [
                        frame_indices[select_id] for select_id in select_ids
                    ]
                else:
                    continue
            else:
                break
        if hasattr(video_reader, "get_batch"):
            video_clip = torch.from_numpy(
                video_reader.get_batch(frame_indices).asnumpy()).permute(
                    0, 3, 1, 2)
        else:
            # PIL fallback path: stack HWC frames in numpy, then convert once.
            video_clip_array = np.stack(
                [np.array(video_reader[i]) for i in frame_indices], axis=0)
            video_clip = torch.from_numpy(video_clip_array).permute(0, 3, 1, 2)
        height, width = video_clip.shape[-2:]
        resized_height, resized_width = get_resized_hw_for_Navit(
            height,
            width,
            min_pixels=self.min_pixels,
            max_pixels=max_pixels,
        )
        resized_video_clip = resize(video_clip,
                                    (resized_height, resized_width),
                                    interpolation=InterpolationMode.BICUBIC,
                                    antialias=True)
        if self.use_timestamp:
            resized_video_clip = [
                (round(i / fps, 1), f)
                for i, f in zip(frame_indices, resized_video_clip)
            ]
        return resized_video_clip
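
    # Example (hypothetical file): with the default use_timestamp=True,
    #   infer.preprocess_video('demo.mp4')
    # returns a list of (timestamp_seconds, uint8 CHW tensor) pairs; with
    # use_timestamp=False it returns a single uint8 TCHW tensor instead.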

    def preprocess_streaming_frame(self, frame: torch.Tensor):
        height, width = frame.shape[-2:]
        resized_height, resized_width = get_resized_hw_for_Navit(
            height,
            width,
            min_pixels=self.min_pixels,
            max_pixels=self.max_pixels_choices[0],
        )
        resized_frame = resize(frame[None], (resized_height, resized_width),
                               interpolation=InterpolationMode.BICUBIC,
                               antialias=True)[0]
        return resized_frame

    def encode_image(self, image: torch.Tensor) -> str:
        if image.shape[0] == 4:
            # Drop the alpha channel: encode_jpeg only accepts 1- or 3-channel input.
            image = image[:3]
        encoded = encode_jpeg(image)
        return base64.b64encode(encoded.numpy()).decode('utf-8')
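
    # Example: any uint8 CHW tensor encodes to a base64 JPEG string,
    #   b64 = infer.encode_image(torch.zeros(3, 28, 28, dtype=torch.uint8))
    # ready to be embedded in a data URL, as construct_messages does below.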

    def construct_messages(self,
                           inputs: dict,
                           streaming_timestamp: Optional[int] = None,
                           online: bool = False) -> list[dict]:
        content = []
        for i, path in enumerate(inputs.get('files', [])):
            if path.endswith('.mp4'):
                video = self.preprocess_video(video_path=path)
                for frame in video:
                    if self.use_timestamp:
                        timestamp, frame = frame
                        content.append({
                            "type": "text",
                            "text": f'[{timestamp} second]',
                        })
                    content.append({
                        "type": "image_url",
                        "image_url": {
                            "url":
                            f"data:image/jpeg;base64,{self.encode_image(frame)}",
                            "detail": "high"
                        },
                    })
            else:
                # Try torchvision first, then PIL, then OpenCV (BGR -> RGB).
                try:
                    image = read_image(path, mode=ImageReadMode.RGB)
                except Exception:
                    try:
                        image = pil_to_tensor(Image.open(path).convert('RGB'))
                    except Exception:
                        image = torch.from_numpy(
                            cv2.cvtColor(cv2.imread(path),
                                         cv2.COLOR_BGR2RGB)).permute(2, 0, 1)
                if online and path.endswith('.webp'):
                    streaming_timestamp = i
                if streaming_timestamp is not None:
                    image = self.preprocess_streaming_frame(frame=image)
                content.append({
                    "type": "image_url",
                    "image_url": {
                        "url":
                        f"data:image/jpeg;base64,{self.encode_image(image)}",
                        "detail": "high"
                    },
                })
        if streaming_timestamp is not None:
            content.insert(-1, {
                "type": "text",
                "text": f'[{streaming_timestamp} second]',
            })
        query = inputs.get('text', '')
        if query:
            content.append({
                "type": "text",
                "text": query,
            })
        messages = [{
            "role": "user",
            "content": content,
        }]
        return messages
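
    # Example (hypothetical input):
    #   construct_messages({'files': ['cat.jpg'], 'text': 'Describe this image.'})
    # yields one user turn of the form
    #   [{'role': 'user', 'content': [
    #       {'type': 'image_url',
    #        'image_url': {'url': 'data:image/jpeg;base64,...', 'detail': 'high'}},
    #       {'type': 'text', 'text': 'Describe this image.'}]}]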

    def request(self,
                messages,
                thinking: bool = True,
                temperature: float = 1.0):
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }
        payload = {
            "model": self.model_id,
            "messages": messages,
            "stream": True,
            "thinking": {
                "type": "enabled" if thinking else "disabled",
            },
            "temperature": temperature,
        }
        # Retry transient connection errors up to 10 times before giving up.
        for _ in range(10):
            try:
                requested = requests.post(self.base_url,
                                          headers=headers,
                                          json=payload,
                                          stream=True,
                                          timeout=600)
                break
            except Exception as e:
                time.sleep(0.1)
                print(e)
        else:
            raise RuntimeError('failed to reach the API after 10 attempts')
        content, reasoning_content = '', ''
        # The endpoint streams server-sent events: each payload line starts
        # with "data: " and the stream terminates with "data: [DONE]".
        for line in requested.iter_lines():
            if not line:
                continue
            if line.startswith(b'data:'):
                data = line[len(b"data: "):]
                if data == b"[DONE]":
                    yield content, reasoning_content, True
                    break
                delta = json.loads(data)['choices'][0]['delta']
                content += delta.get('content', '')
                reasoning_content += delta.get('reasoning_content', '')
                yield content, reasoning_content, False

    def __call__(self,
                 inputs: dict,
                 history: list[dict] = [],
                 mode: str = ConversationModeI18N.D,
                 temperature: float = 1.0,
                 online: bool = False):
        messages = self.construct_messages(inputs=inputs, online=online)
        updated_history = history + messages
        for response, reasoning, finished in self.request(
                messages=updated_history,
                thinking=mode == ConversationModeI18N.D,
                temperature=temperature):
            if mode == ConversationModeI18N.D:
                # In Deep Thinking mode, surface the reasoning trace inline.
                response = '<think>' + reasoning + '</think>' + response
            yield response, updated_history + [{
                'role': 'assistant',
                'content': response
            }], finished
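

# Minimal usage sketch. Assumptions: the model_id and api_key placeholders and
# the local file 'demo.jpg' are hypothetical; substitute real Ark credentials
# and media before running.
if __name__ == '__main__':
    infer = SeedVLInfer(model_id='<your-endpoint-or-model-id>',
                        api_key='<your-ark-api-key>')
    inputs = {'files': ['demo.jpg'], 'text': 'Describe this image.'}
    # Stream the response; print only the final, complete answer.
    for response, history, finished in infer(inputs,
                                             mode=ConversationModeI18N.G):
        if finished:
            print(response)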