"""Encoders that turn raw image and video tensors into sequences of LLM input
embeddings, delegating tokenization and vision encoding to a parent
multimodal model."""

from functools import partial
from typing import Any, Dict, List, Optional

import torch
from torch import nn


class BaseEncoder(nn.Module):
    def __init__(self, parent: nn.Module) -> None:
        super().__init__()
        # Keep the parent in a plain list so that nn.Module does not register
        # it as a submodule: the parent owns this encoder, and registering it
        # here would create a cycle in the module tree.
        self._parent = [parent]

    @property
    def parent(self) -> nn.Module:
        return self._parent[0]


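# Illustration of why `_parent` is hidden in a list (the `Child` class below
# is a hypothetical example, not part of this module): a direct attribute
# assignment would register the parent as a submodule, so the parent would
# appear inside its own module hierarchy, e.g. when collecting `state_dict()`.
#
#     class Child(nn.Module):
#         def __init__(self, parent: nn.Module) -> None:
#             super().__init__()
#             self.parent = parent  # registered via nn.Module.__setattr__
#
# Wrapping the reference in a list bypasses that registration entirely.

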
class BasicImageEncoder(BaseEncoder):
    def __init__(
        self,
        parent: nn.Module,
        start_tokens: Optional[str] = None,
        end_tokens: Optional[str] = "\n",
    ) -> None:
        super().__init__(parent)
        # Optional delimiter strings whose embeddings are spliced around the
        # image features (the default "\n" terminates each image).
        self.start_tokens = start_tokens
        self.end_tokens = end_tokens

    def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]:
        if tokens is None:
            return None
        # Tokenize the delimiter string and map it to input embeddings via
        # the parent LLM's embedding table.
        token_ids = self.parent.tokenizer(tokens).input_ids
        token_ids = torch.tensor(token_ids, device=self.parent.device)
        return self.parent.llm.model.embed_tokens(token_ids)

    def _process_features(
        self,
        features: torch.Tensor,
        start_token_embeds: Optional[torch.Tensor],
        end_token_embeds: Optional[torch.Tensor],
    ) -> torch.Tensor:
        # Splice the optional delimiter embeddings around the image features
        # along the sequence (token) dimension.
        if start_token_embeds is not None:
            features = torch.cat([start_token_embeds, features], dim=0)
        if end_token_embeds is not None:
            features = torch.cat([features, end_token_embeds], dim=0)
        return features

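    # Each sequence returned by forward() below is therefore laid out along
    # the token dimension as:
    #
    #     [start-token embeds][image feature tokens][end-token embeds]
    #
    # with the start/end segments omitted when the corresponding delimiter
    # string is None.
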
    def forward(self, images: List[torch.Tensor], config: Dict[str, Any]) -> List[torch.Tensor]:
        # Batch the images, encode them through the parent model, then wrap
        # each image's features with the delimiter embeddings.
        images = torch.stack(images, dim=0)
        features = self.parent.encode_images(images, block_sizes=config.get("block_sizes"))
        process_features = partial(
            self._process_features,
            start_token_embeds=self.embed_tokens(self.start_tokens),
            end_token_embeds=self.embed_tokens(self.end_tokens),
        )
        return [process_features(f) for f in features]


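# Usage sketch (hypothetical; assumes `model` is a parent module exposing
# `tokenizer`, `device`, `llm.model.embed_tokens`, and `encode_images`, as
# used above):
#
#     encoder = BasicImageEncoder(model)
#     embeds = encoder([image], config={"block_sizes": None})
#     # embeds is a list with one (num_tokens, hidden_dim) tensor per image.

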
class BasicVideoEncoder(BaseEncoder):
    def __init__(
        self,
        parent: nn.Module,
        start_tokens: Optional[str] = None,
        end_tokens: Optional[str] = "\n",
    ) -> None:
        super().__init__(parent)
        # Optional delimiter strings whose embeddings are spliced around the
        # features of every frame.
        self.start_tokens = start_tokens
        self.end_tokens = end_tokens

    def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]:
        if tokens is None:
            return None
        # Identical to BasicImageEncoder.embed_tokens: tokenize the delimiter
        # string and map it through the parent LLM's embedding table.
        token_ids = self.parent.tokenizer(tokens).input_ids
        token_ids = torch.tensor(token_ids, device=self.parent.device)
        return self.parent.llm.model.embed_tokens(token_ids)

    def _process_features(
        self,
        features: torch.Tensor,
        start_token_embeds: Optional[torch.Tensor],
        end_token_embeds: Optional[torch.Tensor],
    ) -> torch.Tensor:
        # `features` has shape (num_frames, tokens_per_frame, hidden_dim).
        # Prepend/append the delimiter embeddings to every frame, then
        # flatten the frames into a single token sequence.
        if start_token_embeds is not None:
            start_embeds = torch.stack([start_token_embeds] * features.shape[0], dim=0)
            features = torch.cat([start_embeds, features], dim=1)
        if end_token_embeds is not None:
            end_embeds = torch.stack([end_token_embeds] * features.shape[0], dim=0)
            features = torch.cat([features, end_embeds], dim=1)
        return features.flatten(0, 1)

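    # Worked shape example (illustrative numbers): a video with F = 8 frames
    # of T = 196 tokens each, wrapped with S = 1 start token and E = 1 end
    # token per frame, flattens to a single sequence of
    # F * (S + T + E) = 8 * 198 = 1584 embeddings:
    #
    #     [start][frame 1 tokens][end][start][frame 2 tokens][end]...
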
    def forward(self, videos: List[torch.Tensor], config: Dict[str, Any]) -> List[torch.Tensor]:
        # Encode all frames of all videos in one batch, then split the frame
        # features back into their source videos before wrapping them.
        num_frames = [video.shape[0] for video in videos]
        images = torch.cat(videos, dim=0)
        features = self.parent.encode_images(images)
        features = torch.split(features, num_frames)
        process_features = partial(
            self._process_features,
            start_token_embeds=self.embed_tokens(self.start_tokens),
            end_token_embeds=self.embed_tokens(self.end_tokens),
        )
        return [process_features(f) for f in features]
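

# Usage sketch (hypothetical, mirroring the image example above; `model` is
# assumed to expose the same tokenizer/device/llm/encode_images interface):
#
#     encoder = BasicVideoEncoder(model)
#     embeds = encoder([video], config={})
#     # embeds holds one flattened token sequence per input video.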