import os
import json
from typing import Dict, List, Optional

import torch
from torch import nn
import torch.nn.functional as F

from transformers import AutoModel

os.environ["TOKENIZERS_PARALLELISM"] = "false"


class KananaEmbeddingWrapper(nn.Module):
    def __init__(
        self,
        model_name_or_path: str,
        trust_remote_code: bool = True,
        device: str = "cpu",
        max_seq_length: Optional[int] = None,
    ):
        """
        Initialize the KananaEmbeddingWrapper.

        Args:
            model_name_or_path: Path or name of the pretrained model
            trust_remote_code: Whether to trust remote code when loading the model
            device: Device to load the model on (e.g., 'cpu', 'cuda')
            max_seq_length: Maximum sequence length; defaults to the model's
                max_position_embeddings when not given
        """
        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.trust_remote_code = trust_remote_code
        self.device = device
        self.kanana2vec = AutoModel.from_pretrained(
            model_name_or_path, trust_remote_code=trust_remote_code
        ).to(self.device)

        self.max_seq_length = (
            max_seq_length
            if max_seq_length is not None
            else self.kanana2vec.config.max_position_embeddings
        )

    def get_sentence_embedding_dimension(self) -> int:
        """
        Returns the dimension of the sentence embeddings.

        Returns:
            Dimensionality of the sentence embeddings
        """
        return self.kanana2vec.config.hidden_size

    def get_max_seq_length(self) -> int:
        """
        Returns the maximum sequence length this module can process.

        Returns:
            Maximum sequence length
        """
        return self.max_seq_length

    def tokenize(self, texts: List[str]) -> Dict[str, torch.Tensor]:
        """
        Tokenize input texts.

        Args:
            texts: List of input texts to tokenize

        Returns:
            Dictionary containing tokenized inputs
        """
        return self.kanana2vec.tokenizer(
            texts,
            padding=True,
            return_token_type_ids=False,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_seq_length,
        ).to(self.device)

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Forward pass of the module.

        Args:
            features: Dictionary with inputs including 'input_ids', 'attention_mask', etc.

        Returns:
            Dictionary with updated features including 'sentence_embedding'
        """
        # Keep only the tensors the underlying model expects.
        model_inputs = self._extract_model_inputs(features)

        # Mask out prompt tokens so they are excluded from pooling.
        model_inputs["pool_mask"] = self._create_pool_mask(features)

        # Embed and L2-normalize so cosine similarity reduces to a dot product.
        embedding = self.kanana2vec.forward(**model_inputs).embedding
        normalized_embedding = F.normalize(embedding, p=2, dim=1)

        features['sentence_embedding'] = normalized_embedding
        return features

    def _extract_model_inputs(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Extract only the inputs needed for the model.

        Args:
            features: Complete feature dictionary

        Returns:
            Dictionary with only the required keys for the model
        """
        return {k: v for k, v in features.items() if k in ['input_ids', 'attention_mask']}

    def _create_pool_mask(self, features: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Create a pool mask based on attention mask and prompt length.

        Args:
            features: Feature dictionary containing attention_mask and optionally prompt_length

        Returns:
            Pool mask tensor
        """
        pool_mask = features['attention_mask'].clone()
        if "prompt_length" in features:
            pool_mask[:, :features['prompt_length']] = 0
        return pool_mask

    def get_config_dict(self) -> Dict:
        """
        Returns a dictionary with the module's configuration.

        Returns:
            Dictionary with module configuration
        """
        return {
            "model_name_or_path": self.model_name_or_path,
            "trust_remote_code": self.trust_remote_code,
            "device": self.device,
            "hidden_size": self.get_sentence_embedding_dimension(),
            "max_seq_length": self.get_max_seq_length(),
        }

    def save(self, save_dir: str) -> None:
        """
        Saves the module's configuration and model to the specified directory.

        Args:
            save_dir: Directory to save the module configuration
        """
        os.makedirs(save_dir, exist_ok=True)

        # Persist the wrapper configuration alongside the model weights.
        config_path = os.path.join(save_dir, "kanana_embedding_config.json")
        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(self.get_config_dict(), f, ensure_ascii=False, indent=2)

        model_save_path = os.path.join(save_dir, "kanana2vec")
        self.kanana2vec.save_pretrained(model_save_path)

        print(f"KananaEmbeddingWrapper model saved to {save_dir}")

    @staticmethod
    def load(load_dir: str, device: str = "cpu") -> 'KananaEmbeddingWrapper':
        """
        Loads a KananaEmbeddingWrapper model from the specified directory.

        Args:
            load_dir: Directory containing the saved module
            device: Device to load the model on

        Returns:
            Initialized KananaEmbeddingWrapper
        """
        config_path = os.path.join(load_dir, "kanana_embedding_config.json")
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)

        model_load_path = os.path.join(load_dir, "kanana2vec")

        instance = KananaEmbeddingWrapper(
            model_name_or_path=model_load_path,
            trust_remote_code=config.get("trust_remote_code", True),
            device=device,
            max_seq_length=config.get("max_seq_length"),
        )

        return instance
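

# --- Usage sketch (illustrative only) ---
# A minimal example of how this wrapper might be used end to end: tokenize,
# run the forward pass, and compare the normalized embeddings. The model
# identifier below is a placeholder, not a confirmed checkpoint name; point
# it at whichever Kanana embedding model you actually use.
if __name__ == "__main__":
    wrapper = KananaEmbeddingWrapper(
        model_name_or_path="path/to/kanana-embedding-model",  # placeholder path
        device="cpu",
    )

    sentences = ["The cat sits on the mat.", "A cat is resting on a rug."]

    # tokenize() returns tensors already placed on the wrapper's device.
    features = wrapper.tokenize(sentences)
    with torch.no_grad():
        outputs = wrapper(dict(features))
    embeddings = outputs["sentence_embedding"]  # shape: (2, hidden_size)

    # Embeddings are L2-normalized, so cosine similarity is a plain dot product.
    similarity = embeddings @ embeddings.T
    print(similarity)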