import itertools import json import pickle from random import Random from datasets import load_dataset import faiss import pandas as pd import numpy as np import torch from huggingface_hub import hf_hub_download from sentence_transformers import SentenceTransformer class GaussianCoveragePooling(torch.nn.Module): def __init__(self, coverage_chunks, sigma, alpha): """ Custom pooling layer that computes weighted mean pooling using Gaussian-based weights. Args: coverage_chunks (int): Number of weighted pooling operations (N). sigma (float): Standard deviation for Gaussian weighting. alpha (float): Weighting factor for merging with standard mean pooling. """ super().__init__() self.coverage_chunks = coverage_chunks self.sigma = sigma # Controls width of Gaussians self.alpha = alpha # Blends standard mean with weighted mean def forward(self, features, chunk_indicators=None): """ Computes weighted mean pooling using Gaussian-based weights. Args: self (SentenceTransformer): The model. features (dict): The token embeddings and attention mask. chunk_indicators (tensor[bz, 1]): Index indicators to return a specific chunk, leave as None to return embeddings for all chunks. Mainly useful for training, not inference. Leave as None for inference. """ # Get token embeddings and attention mask token_embeddings = features[ "token_embeddings" ] # (batch_size, seq_len, hidden_dim) attention_mask = ( features["attention_mask"].float().unsqueeze(-1) ) # (batch_size, seq_len, 1) # Get shapes and devices batch_size, seq_len, hidden_dim = token_embeddings.shape device = token_embeddings.device # Compute actual sequence lengths (ignoring padding) # (batch_size, 1) seq_lengths = attention_mask.squeeze(-1).sum(dim=1, keepdim=True) max_seq_length = int(torch.max(seq_lengths).item()) # Standard mean pooling sum_embeddings = torch.sum(token_embeddings * attention_mask, dim=1) sum_mask = torch.sum(attention_mask, dim=1).clamp(min=1e-9) standard_mean = sum_embeddings / sum_mask # (batch_size, hidden_dim) # Compute chunk centers dynamically based on sequence length chunk_positions = torch.linspace(0, 1, self.coverage_chunks + 2, device=device)[ 1:-1 ] # Excludes 0 and 1 chunk_centers = chunk_positions * seq_lengths # (batch_size, N) # Token positions per sequence (batch_size, seq_len) token_positions = ( torch.arange(seq_len, device=device).float().unsqueeze(0) ) # (1, seq_len) # Compute Gaussian weights (batch_size, N, seq_len) seq_lengths = seq_lengths.view(seq_lengths.shape[0], 1, 1).repeat( 1, self.coverage_chunks, max_seq_length ) gaussians = torch.exp( -0.5 * ( (token_positions.unsqueeze(1) - chunk_centers.unsqueeze(2)) / (self.sigma * seq_lengths) ) ** 2 ) # Mask out padding and normalize Gaussian weights per sequence # (batch_size, N, seq_len) gaussians = gaussians * attention_mask.squeeze(-1).unsqueeze(1) # Normalize against gaussian weights gaussians /= gaussians.sum(dim=2, keepdim=True).clamp(min=1e-9) # Compute weighted mean for each chunk (batch_size, N, hidden_dim) weighted_means = torch.einsum( "bns,bsh->bnh", gaussians.to(token_embeddings.dtype), token_embeddings ) # Blend with standard mean pooling # (batch_size, N, hidden_dim) combined_embeddings = (1 - self.alpha) * standard_mean.unsqueeze( 1 ) + self.alpha * weighted_means # Add an embedding for the entire document at index 0 # (batch_size, N+1, hidden_dim) combined_embeddings = torch.cat( [torch.zeros_like(combined_embeddings[:, :1]), combined_embeddings], 1 ) combined_embeddings[:, 0:1, :] = standard_mean.unsqueeze(1) # Select the indicator if provided if chunk_indicators is not None: combined_embeddings = combined_embeddings[ torch.arange(combined_embeddings.size(0)), chunk_indicators ] # Normalize all the embeddings combined_embeddings = torch.nn.functional.normalize( combined_embeddings, p=2, dim=-1 ) # Flatten final embeddings (batch_size, hidden_dim * (N+1)) if chunk_indicators is None: sentence_embedding = combined_embeddings.reshape( batch_size, hidden_dim * (self.coverage_chunks + 1) ) else: sentence_embedding = combined_embeddings # Return the final flattened entence embedding features["sentence_embedding"] = sentence_embedding return features def use_gaussian_coverage_pooling(m, coverage_chunks=10, sigma=0.05, alpha=1.0): """ Add custom pooling layer that computes weighted mean pooling using Gaussian-based weights. Args: m (SentenceTransformer): The model to add pooling layer to. coverage_chunks (int): Number of weighted pooling operations (N). sigma (float): Standard deviation for Gaussian weighting. alpha (float): Weighting factor for merging with standard mean pooling. """ if isinstance(m[1], GaussianCoveragePooling): m = unuse_gaussian_coverage_pooling(m) word_embedding_model = m[0] custom_pooling = GaussianCoveragePooling( coverage_chunks=coverage_chunks, sigma=sigma, alpha=alpha ) old_pooling = m[1] new_m = m.__class__(modules=[word_embedding_model, custom_pooling]) new_m.old_pooling = {"old_pooling": old_pooling} return new_m def unuse_gaussian_coverage_pooling(m): """ Removes the custom pooling layer. Args: m (SentenceTransformer): The model to remove the pooling layer from. """ if isinstance(m[1], GaussianCoveragePooling): new_m = m.__class__(modules=[m[0], m.old_pooling["old_pooling"]]) return new_m else: return m class InstructionTemplateRetriever: FINETEMPLATES_REVISION = "da7b725f7c1248f8c632d7d53c94109ae5dadbfe" RETRIEVAL_EMBEDDING_NAME = ( "fineinstructions/instruction_template_retrieval_embedding" ) RETRIEVAL_EMBEDDING_REVISION = "db4efbde126216250ffa5a356663fc7da3bf7856" def __init__( self, coverage_chunks=10, sigma=0.05, alpha=1.0, nprobe=150, ): """ Computes embeddings that cover a document to find relevant instruction templates using Gaussian-weighted embeddings that cover different parts of the document. Args: coverage_chunks (int): The number of equally sized chunks/sections to get coverage over the entire document. sigma (float): Standard deviation for Gaussian weighting, this will essentially control how "wide" / "focused" each chunk is. alpha (float): A weighting factor to control how much to balance the representation of a single chunk, versus the representation of the entire document. nprobe (int): The number of probes to use when searching the FAISS index (larger is more accurate, but slower). """ self.d = load_dataset( "fineinstructions/finetemplates", revision=InstructionTemplateRetriever.FINETEMPLATES_REVISION, split="full", ) self.m = SentenceTransformer( InstructionTemplateRetriever.RETRIEVAL_EMBEDDING_NAME, revision=InstructionTemplateRetriever.RETRIEVAL_EMBEDDING_REVISION, device="cpu", ) self.m = use_gaussian_coverage_pooling( self.m, coverage_chunks=coverage_chunks, sigma=sigma, alpha=alpha ) self.index = faiss.read_index( hf_hub_download( "fineinstructions/finetemplates", "faiss_index/finetemplates.index", revision=InstructionTemplateRetriever.FINETEMPLATES_REVISION, repo_type="dataset", ), faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY, ) self.index.nprobe = nprobe if torch.cuda.is_available(): self.m = self.m.to("cuda") elif torch.backends.mps.is_available(): self.m = self.m.to("mps") with open( hf_hub_download( "fineinstructions/finetemplates", "faiss_index/reweighting_stats.pkl", revision=InstructionTemplateRetriever.FINETEMPLATES_REVISION, repo_type="dataset", ), "rb", ) as reweighting_stats_fp: reweighting_stats = pickle.load(reweighting_stats_fp) self.resampling_weights = reweighting_stats["resampling_weights"] self.template_variable_count_mapping = reweighting_stats[ "template_variable_count_mapping" ] def _filter_rows(self, rows, filter_string): if not rows: return [] df = pd.DataFrame(rows) try: filtered_df = df.query(filter_string) return filtered_df.to_dict(orient="records") except Exception as e: return rows def search( self, document, filters="", search_k=20000, max_results=250, deduplicate=True, reweight=False, reweight_epsilon=0.05, ): """ Given a document Args: document (str): The document to retrieve relevant instruction templates for. filters (str): A query string in the format of pandas.DataFrame.query() search_k (int): The number of search results to pull when retrieving from FAISS. max_results (int): The max number of results to return. deduplicate (bool): Deduplicate results between coverage sections. reweight (bool): Whether to reweight the results based on a more realistic length distribution. reweight_epsilon (float): How tolerant to be when reweighting (larger is more inaccurate results but better reweighting) """ def _reweight(inp, k=None): if reweight: inp0, inp = itertools.tee(inp) first_row = next(inp0) r = Random(first_row[1].item()) epsilon = reweight_epsilon bucket = first_row[1] items = [] weights = [] for i, s in inp: if abs(bucket - s.item()) <= epsilon: items.append((i, s)) weights.append( self.resampling_weights[ self.template_variable_count_mapping[i.item()] ] ) else: break return r.choices( items, weights=weights, k=(len(items) if k is None else k) ) else: return inp # Search FAISS index vecs = self.m.encode([document], normalize_embeddings=False).reshape( -1, self.m[0].auto_model.config.hidden_size ) scores_batch, indices_batch = self.index.search(np.vstack(vecs), k=search_k) # Pull in FineTemplates rows into memory to_select = [i.item() for i in itertools.chain.from_iterable(indices_batch)] d_in_mem = { i: row for i, row in zip(to_select, self.d.select(to_select).to_list()) } # Group by coverage chunk true_coverage_chunks = self.m[1].coverage_chunks + 1 scores_per_input, indices_per_input = ( [ scores_batch[i : i + true_coverage_chunks] for i in range(0, len(scores_batch), true_coverage_chunks) ], [ indices_batch[i : i + true_coverage_chunks] for i in range(0, len(indices_batch), true_coverage_chunks) ], ) # Get the results for the first result in the batch (assuming bz=1) scores_per_input, indices_per_input = scores_per_input[0], indices_per_input[0] # Create result rows rows = [ [ { "coverage_section": f"{chunk_idx}/{self.m[1].coverage_chunks}" if chunk_idx > 0 else "Entire Document", "score": s.item(), **d_in_mem[i.item()], } for i, s in _reweight(zip(indices, scores), k=None) ] for chunk_idx, (indices, scores) in enumerate( zip(indices_per_input, scores_per_input) ) ] # Deduplicate if deduplicate: seen = set() rows = [ r for r in itertools.chain.from_iterable(zip(*rows)) if (len(seen) != len(seen.add(r["template_id"]) or seen)) ] else: rows = list(itertools.chain.from_iterable(zip(*rows))) # Filter rows = self._filter_rows(rows, filters)[:max_results] # Return rows return rows