"""Generate embeddings for affiliation strings with a pre-trained encoder."""

import os
import sys
import argparse
import logging
from pathlib import Path
from typing import List, Dict, Any, Optional
import warnings

import torch
import torch.nn.functional as F
import pandas as pd
import numpy as np
from tqdm import tqdm
from datasets import Dataset, DatasetDict
from transformers import AutoModel, AutoTokenizer

warnings.filterwarnings('ignore')

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('embedding_generation.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)


class AffiliationEmbedder:
    def __init__(
        self,
        model_path: str = "./affiliation-clustering-0.3b",
        device: Optional[str] = None,
        batch_size: int = 32,
        max_length: int = 512,
        use_fp16: bool = False
    ):
        self.model_path = model_path
        self.batch_size = batch_size
        self.max_length = max_length
        self.use_fp16 = use_fp16

        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        logger.info(f"Using device: {self.device}")
        if self.device.type == 'cuda':
            logger.info(f"GPU: {torch.cuda.get_device_name()}")
            logger.info(f"Memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

        self._load_model()

    def _load_model(self):
        logger.info(f"Loading model from {self.model_path}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_path,
                trust_remote_code=True
            )
            self.model = AutoModel.from_pretrained(
                self.model_path,
                trust_remote_code=True
            )
            self.model = self.model.to(self.device)

            if self.use_fp16 and self.device.type == 'cuda':
                self.model = self.model.half()
                logger.info("Using FP16 mixed precision")

            self.model.eval()
            logger.info("Model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

    def encode_batch(self, texts: List[str]) -> np.ndarray:
        encoded = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )
        encoded = {k: v.to(self.device) for k, v in encoded.items()}

        with torch.no_grad():
            outputs = self.model(**encoded)

        # Prefer the model's pooler output when available; otherwise fall back
        # to mean pooling over non-padding tokens.
        if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
            embeddings = outputs.pooler_output
        else:
            token_embeddings = outputs.last_hidden_state
            attention_mask = encoded['attention_mask'].unsqueeze(-1)
            masked_embeddings = token_embeddings * attention_mask
            embeddings = masked_embeddings.sum(dim=1) / attention_mask.sum(dim=1)

        # L2-normalize so cosine similarity reduces to a dot product downstream.
        embeddings = F.normalize(embeddings, p=2, dim=1)
        embeddings = embeddings.cpu().numpy()

        if self.use_fp16:
            # Store embeddings as float32 for downstream compatibility.
            embeddings = embeddings.astype(np.float32)

        return embeddings

    def process_dataset(
        self,
        data_path: str,
        output_path: str,
        checkpoint_interval: int = 1000
    ) -> None:
        logger.info(f"Processing dataset: {data_path}")
        df = pd.read_parquet(data_path)
        logger.info(f"Loaded {len(df)} samples")

        checkpoint_path = output_path.replace('.parquet', '_checkpoint.parquet')
        start_idx = 0
        processed_rows = []

        if os.path.exists(checkpoint_path):
            logger.info(f"Found checkpoint at {checkpoint_path}")
            checkpoint_df = pd.read_parquet(checkpoint_path)
            # Keep the already-embedded rows so the final output (and later
            # checkpoints) still contain them after resuming.
            processed_rows = checkpoint_df.to_dict('records')
            start_idx = len(checkpoint_df)
            logger.info(f"Resuming from index {start_idx}")

        total_batches = (len(df) - start_idx + self.batch_size - 1) // self.batch_size

        with tqdm(total=total_batches, desc="Generating embeddings") as pbar:
            for i in range(start_idx, len(df), self.batch_size):
                batch_df = df.iloc[i:i + self.batch_size]
                texts = batch_df['affiliation_name'].tolist()

                try:
                    batch_embeddings = self.encode_batch(texts)

                    for j, embedding in enumerate(batch_embeddings):
                        row_idx = i + j
                        row_data = df.iloc[row_idx].to_dict()
                        row_data['embedding'] = embedding
                        processed_rows.append(row_data)

                        if len(processed_rows) % checkpoint_interval == 0:
                            self._save_checkpoint(processed_rows, checkpoint_path)
                            logger.info(f"Checkpoint saved at {len(processed_rows)} samples")

                    pbar.update(1)
                except Exception as e:
                    logger.error(f"Error processing batch at index {i}: {e}")
                    if processed_rows:
                        self._save_checkpoint(processed_rows, checkpoint_path)
                    raise

        result_df = pd.DataFrame(processed_rows)
        logger.info(f"Saving embeddings to {output_path}")
        result_df.to_parquet(output_path, compression='snappy')

        if os.path.exists(checkpoint_path):
            os.remove(checkpoint_path)
            logger.info("Checkpoint file removed")

        logger.info(f"Successfully generated embeddings for {len(result_df)} samples")
        embedding_dim = len(result_df['embedding'].iloc[0])
        logger.info(f"Embedding dimension: {embedding_dim}")
        logger.info(f"Output file size: {os.path.getsize(output_path) / 1e6:.2f} MB")

    def _save_checkpoint(self, processed_rows: List[Dict], checkpoint_path: str):
        checkpoint_df = pd.DataFrame(processed_rows)
        checkpoint_df.to_parquet(checkpoint_path, compression='snappy')


def main():
    parser = argparse.ArgumentParser(
        description="Generate embeddings for affiliation strings"
    )
    parser.add_argument(
        "--model-path",
        type=str,
        default="./affiliation-clustering-0.3b",
        help="Path to the pre-trained model directory"
    )
    parser.add_argument(
        "--data-dir",
        type=str,
        default="./20250727-unique-openalex-affiliations-w-ror-ids-top-1K-ror-ids-100-per-sample",
        help="Directory containing the input parquet files"
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="./20250727-unique-openalex-affiliations-w-ror-ids-top-1K-ror-ids-100-per-sample-embeddings",
        help="Directory to save the output embeddings"
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Batch size for processing"
    )
    parser.add_argument(
        "--max-length",
        type=int,
        default=512,
        help="Maximum sequence length for tokenization"
    )
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to use (cuda/cpu, auto-detect if not specified)"
    )
    parser.add_argument(
        "--use-fp16",
        action="store_true",
        help="Use FP16 mixed precision for faster processing"
    )
    parser.add_argument(
        "--checkpoint-interval",
        type=int,
        default=1000,
        help="Save a checkpoint every N processed samples"
    )
    parser.add_argument(
        "--push-to-hub",
        action="store_true",
        help="Push the resulting dataset to Hugging Face Hub"
    )
    parser.add_argument(
        "--hub-dataset-id",
        type=str,
        default=None,
        help="Hugging Face Hub dataset ID (required if --push-to-hub is set)"
    )

    args = parser.parse_args()

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    embedder = AffiliationEmbedder(
        model_path=args.model_path,
        device=args.device,
        batch_size=args.batch_size,
        max_length=args.max_length,
        use_fp16=args.use_fp16
    )

    data_dir = Path(args.data_dir)
    train_file = list(data_dir.glob("*_train.parquet"))[0]
    test_file = list(data_dir.glob("*_test.parquet"))[0]

    train_output = output_dir / "train_embeddings.parquet"
    test_output = output_dir / "test_embeddings.parquet"

    logger.info("Processing training dataset...")
    embedder.process_dataset(
        str(train_file),
        str(train_output),
        checkpoint_interval=args.checkpoint_interval
    )

    logger.info("Processing test dataset...")
    embedder.process_dataset(
        str(test_file),
        str(test_output),
        checkpoint_interval=args.checkpoint_interval
    )

    if args.push_to_hub:
        if not args.hub_dataset_id:
            logger.error("--hub-dataset-id is required when --push-to-hub is set")
            sys.exit(1)
logger.info(f"Pushing dataset to Hugging Face Hub: {args.hub_dataset_id}") try: from huggingface_hub import HfApi, login token = os.environ.get('HF_TOKEN') or os.environ.get('HUGGING_FACE_HUB_TOKEN') if token: login(token=token) logger.info("Authenticated with Hugging Face Hub using token") else: logger.info("No HF token found in environment, attempting to use existing credentials") logger.info("Loading generated embeddings...") train_df = pd.read_parquet(train_output) test_df = pd.read_parquet(test_output) logger.info(f"Train dataset: {len(train_df)} samples") logger.info(f"Test dataset: {len(test_df)} samples") logger.info("Creating dataset dictionary...") dataset_dict = DatasetDict({ 'train': Dataset.from_pandas(train_df), 'test': Dataset.from_pandas(test_df) }) logger.info(f"Pushing to hub: {args.hub_dataset_id}") dataset_dict.push_to_hub( args.hub_dataset_id, private=False, commit_message="Add affiliation embeddings generated with affiliation-clustering-0.3b model" ) logger.info(f"Dataset successfully pushed to {args.hub_dataset_id}") logger.info(f"View at: https://huggingface.co/datasets/{args.hub_dataset_id}") except ImportError as e: logger.error(f"Failed to import required libraries: {e}") logger.error("Make sure huggingface_hub and datasets are installed") sys.exit(1) except Exception as e: logger.error(f"Failed to push dataset to hub: {e}") logger.error(f"Error type: {type(e).__name__}") import traceback logger.error(f"Traceback: {traceback.format_exc()}") sys.exit(1) logger.info("Embedding generation completed successfully!") if __name__ == "__main__": main()