import os
import json
import torch
from torch import nn
import torch.nn.functional as F
from typing import List, Dict, Optional
from transformers import AutoModel

os.environ["TOKENIZERS_PARALLELISM"] = "false"


class KananaEmbeddingWrapper(nn.Module):
    def __init__(
        self,
        model_name_or_path: str,
        trust_remote_code: bool = True,
        device: str = "cpu",
        max_seq_length: Optional[int] = None,
    ):
        """
        Initialize the KananaEmbeddingWrapper.

        Args:
            model_name_or_path: Path or name of the pretrained model
            trust_remote_code: Whether to trust remote code when loading the model
            device: Device to load the model on (e.g., 'cpu', 'cuda')
            max_seq_length: Maximum sequence length; defaults to the model's max_position_embeddings
        """
        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.trust_remote_code = trust_remote_code
        self.device = device
        self.kanana2vec = AutoModel.from_pretrained(
            model_name_or_path, trust_remote_code=trust_remote_code
        ).to(self.device)
        self.max_seq_length = (
            max_seq_length
            if max_seq_length is not None
            else self.kanana2vec.config.max_position_embeddings
        )

    def get_sentence_embedding_dimension(self) -> int:
        """
        Returns the dimension of the sentence embeddings.

        Returns:
            Dimensionality of the sentence embeddings
        """
        return self.kanana2vec.config.hidden_size

    def get_max_seq_length(self) -> int:
        """
        Returns the maximum sequence length this module can process.

        Returns:
            Maximum sequence length
        """
        return self.max_seq_length

    def tokenize(self, texts: List[str]) -> Dict[str, torch.Tensor]:
        """
        Tokenize input texts.

        Args:
            texts: List of input texts to tokenize

        Returns:
            Dictionary containing tokenized inputs
        """
        return self.kanana2vec.tokenizer(
            texts,
            padding=True,
            return_token_type_ids=False,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_seq_length,
        ).to(self.device)

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Forward pass of the module.

        Args:
            features: Dictionary with inputs including 'input_ids', 'attention_mask', etc.

        Returns:
            Dictionary with updated features including 'sentence_embedding'
        """
        # Extract only the required features for the model
        model_inputs = self._extract_model_inputs(features)

        # Create pool mask considering prompt length if available
        model_inputs["pool_mask"] = self._create_pool_mask(features)

        # Get embeddings from the model and normalize
        embedding = self.kanana2vec.forward(**model_inputs).embedding
        normalized_embedding = F.normalize(embedding, p=2, dim=1)

        # Update features with sentence embedding
        features["sentence_embedding"] = normalized_embedding
        return features

    def _extract_model_inputs(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Extract only the inputs needed for the model.

        Args:
            features: Complete feature dictionary

        Returns:
            Dictionary with only the required keys for the model
        """
        return {k: v for k, v in features.items() if k in ["input_ids", "attention_mask"]}

    def _create_pool_mask(self, features: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Create a pool mask based on attention mask and prompt length.

        Args:
            features: Feature dictionary containing attention_mask and optionally prompt_length

        Returns:
            Pool mask tensor
        """
        pool_mask = features["attention_mask"].clone()
        if "prompt_length" in features:
            pool_mask[:, : features["prompt_length"]] = 0
        return pool_mask

    def get_config_dict(self) -> Dict:
        """
        Returns a dictionary with the module's configuration.
        Returns:
            Dictionary with module configuration
        """
        return {
            "model_name_or_path": self.model_name_or_path,
            "trust_remote_code": self.trust_remote_code,
            "device": self.device,
            "hidden_size": self.get_sentence_embedding_dimension(),
            "max_seq_length": self.get_max_seq_length(),
        }

    def save(self, save_dir: str) -> None:
        """
        Saves the module's configuration and model to the specified directory.

        Args:
            save_dir: Directory to save the module configuration
        """
        os.makedirs(save_dir, exist_ok=True)

        # Save module configuration
        config_path = os.path.join(save_dir, "kanana_embedding_config.json")
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(self.get_config_dict(), f, ensure_ascii=False, indent=2)

        # Save the underlying model
        model_save_path = os.path.join(save_dir, "kanana2vec")
        self.kanana2vec.save_pretrained(model_save_path)

        print(f"KananaEmbeddingWrapper model saved to {save_dir}")

    @staticmethod
    def load(load_dir: str, device: str = "cpu") -> "KananaEmbeddingWrapper":
        """
        Loads a KananaEmbeddingWrapper model from the specified directory.

        Args:
            load_dir: Directory containing the saved module
            device: Device to load the model on

        Returns:
            Initialized KananaEmbeddingWrapper
        """
        # Load configuration
        config_path = os.path.join(load_dir, "kanana_embedding_config.json")
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)

        # Use the saved model path
        model_load_path = os.path.join(load_dir, "kanana2vec")

        # Create instance with saved configuration
        instance = KananaEmbeddingWrapper(
            model_name_or_path=model_load_path,
            trust_remote_code=config.get("trust_remote_code", True),
            device=device,  # Use the provided device or default
        )

        return instance
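

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): one way this wrapper might be plugged into
# a sentence-transformers pipeline as the first module, since it exposes the
# expected tokenize / forward / get_sentence_embedding_dimension interface.
# The checkpoint name below is an assumption; replace it with the actual
# Kanana embedding model you intend to load.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from sentence_transformers import SentenceTransformer

    # Hypothetical checkpoint name; substitute the real model path or repo id.
    wrapper = KananaEmbeddingWrapper("kakaocorp/kanana-nano-2.1b-embedding", device="cpu")

    # Build a SentenceTransformer from this custom module and encode a few texts.
    model = SentenceTransformer(modules=[wrapper], device="cpu")
    embeddings = model.encode(["first example sentence", "second example sentence"])
    print(embeddings.shape)  # (2, hidden_size)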