# kanana-nano-2.1b-embedding / kanana_embedding_wrapper.py
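"""
Sentence-Transformers-compatible wrapper around the Kanana embedding model.

KananaEmbeddingWrapper exposes the interface sentence-transformers expects from a
custom module (tokenize, a forward pass that adds 'sentence_embedding',
get_sentence_embedding_dimension, save), while delegating the actual encoding to
the Kanana2Vec model loaded via AutoModel with trust_remote_code=True.
"""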
import os
import json
import torch
from torch import nn
import torch.nn.functional as F
from typing import Dict, List, Optional
from transformers import AutoModel
os.environ["TOKENIZERS_PARALLELISM"] = "false"


class KananaEmbeddingWrapper(nn.Module):
    def __init__(self, model_name_or_path: str, trust_remote_code: bool = True,
                 device: str = "cpu", max_seq_length: Optional[int] = None):
        """
        Initialize the KananaEmbeddingWrapper.

        Args:
            model_name_or_path: Path or name of the pretrained model
            trust_remote_code: Whether to trust remote code when loading the model
            device: Device to load the model on (e.g., 'cpu', 'cuda')
            max_seq_length: Maximum sequence length; defaults to the model's
                max_position_embeddings when not given
        """
        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.trust_remote_code = trust_remote_code
        self.device = device
        self.kanana2vec = AutoModel.from_pretrained(
            model_name_or_path, trust_remote_code=trust_remote_code
        ).to(self.device)
        self.max_seq_length = (
            max_seq_length
            if max_seq_length is not None
            else self.kanana2vec.config.max_position_embeddings
        )

    def get_sentence_embedding_dimension(self) -> int:
        """
        Returns the dimension of the sentence embeddings.

        Returns:
            Dimensionality of the sentence embeddings
        """
        return self.kanana2vec.config.hidden_size

    def get_max_seq_length(self) -> int:
        """
        Returns the maximum sequence length this module can process.

        Returns:
            Maximum sequence length
        """
        return self.max_seq_length

    def tokenize(self, texts: List[str]) -> Dict[str, torch.Tensor]:
        """
        Tokenize input texts.

        Args:
            texts: List of input texts to tokenize

        Returns:
            Dictionary containing tokenized inputs
        """
        return self.kanana2vec.tokenizer(
            texts,
            padding=True,
            return_token_type_ids=False,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_seq_length,
        ).to(self.device)

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Forward pass of the module.

        Args:
            features: Dictionary with inputs including 'input_ids', 'attention_mask', etc.

        Returns:
            Dictionary with updated features including 'sentence_embedding'
        """
        # Extract only the required features for the model
        model_inputs = self._extract_model_inputs(features)
        # Create the pool mask, taking the prompt length into account if available
        model_inputs["pool_mask"] = self._create_pool_mask(features)
        # Get embeddings from the model and L2-normalize them
        embedding = self.kanana2vec.forward(**model_inputs).embedding
        normalized_embedding = F.normalize(embedding, p=2, dim=1)
        # Update features with the sentence embedding
        features['sentence_embedding'] = normalized_embedding
        return features

    def _extract_model_inputs(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Extract only the inputs needed for the model.

        Args:
            features: Complete feature dictionary

        Returns:
            Dictionary with only the required keys for the model
        """
        return {k: v for k, v in features.items() if k in ['input_ids', 'attention_mask']}

    def _create_pool_mask(self, features: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Create a pool mask based on the attention mask and prompt length.

        Args:
            features: Feature dictionary containing 'attention_mask' and optionally 'prompt_length'

        Returns:
            Pool mask tensor
        """
        pool_mask = features['attention_mask'].clone()
        if "prompt_length" in features:
            pool_mask[:, :features['prompt_length']] = 0
        return pool_mask
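
    # Note: sentence-transformers passes 'prompt_length' in the feature dict when
    # encode() is called with a prompt/instruction prefix. Zeroing those leading
    # positions in the pool mask excludes the prompt tokens from pooling, so only
    # the actual input text contributes to the sentence embedding.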

    def get_config_dict(self) -> Dict:
        """
        Returns a dictionary with the module's configuration.

        Returns:
            Dictionary with module configuration
        """
        return {
            "model_name_or_path": self.model_name_or_path,
            "trust_remote_code": self.trust_remote_code,
            "device": self.device,
            "hidden_size": self.get_sentence_embedding_dimension(),
            "max_seq_length": self.get_max_seq_length(),
        }

    def save(self, save_dir: str) -> None:
        """
        Saves the module's configuration and model to the specified directory.

        Args:
            save_dir: Directory to save the module configuration
        """
        os.makedirs(save_dir, exist_ok=True)
        # Save the wrapper configuration
        config_path = os.path.join(save_dir, "kanana_embedding_config.json")
        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(self.get_config_dict(), f, ensure_ascii=False, indent=2)
        # Save the underlying model
        model_save_path = os.path.join(save_dir, "kanana2vec")
        self.kanana2vec.save_pretrained(model_save_path)
        print(f"KananaEmbeddingWrapper model saved to {save_dir}")

    @staticmethod
    def load(load_dir: str, device: str = "cpu") -> 'KananaEmbeddingWrapper':
        """
        Loads a KananaEmbeddingWrapper model from the specified directory.

        Args:
            load_dir: Directory containing the saved module
            device: Device to load the model on

        Returns:
            Initialized KananaEmbeddingWrapper
        """
        # Load configuration
        config_path = os.path.join(load_dir, "kanana_embedding_config.json")
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        # Use the saved model path
        model_load_path = os.path.join(load_dir, "kanana2vec")
        # Create an instance with the saved configuration
        instance = KananaEmbeddingWrapper(
            model_name_or_path=model_load_path,
            trust_remote_code=config.get("trust_remote_code", True),
            device=device,  # Use the provided device rather than the saved one
            max_seq_length=config.get("max_seq_length"),  # Restore the saved sequence limit
        )
        return instance
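

# --- Usage sketch ------------------------------------------------------------
# A minimal example of how this wrapper might be used. The model id
# "kakaocorp/kanana-nano-2.1b-embedding" is an assumption; substitute any local
# path or hub id whose remote code exposes a `.tokenizer` attribute and returns
# an `.embedding` field from forward(), as the methods above require.
if __name__ == "__main__":
    wrapper = KananaEmbeddingWrapper(
        model_name_or_path="kakaocorp/kanana-nano-2.1b-embedding",  # assumed id/path
        device="cpu",
    )
    texts = ["카카오는 어디에 있나요?", "Kanana is a lightweight embedding model."]
    # Direct use: tokenize, run the forward pass, and read the L2-normalized embeddings.
    with torch.no_grad():
        features = dict(wrapper.tokenize(texts))
        features = wrapper.forward(features)
    embeddings = features['sentence_embedding']  # shape: (len(texts), hidden_size)
    print(embeddings.shape)
    # The wrapper is also shaped to serve as the first module of a
    # sentence_transformers.SentenceTransformer(modules=[...]) pipeline, which is
    # what its tokenize/forward/save interface is designed for.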