"""SPLADE demo app: sparse-representation explorer and Cranfield retrieval.

Loads three SPLADE variants from the Hugging Face Hub, indexes the Cranfield
collection (via ir_datasets) with one of them, and serves a Gradio UI with an
explorer tab and a retrieval tab.
"""
import os
import random  # Added for random selection

import gradio as gr
import ir_datasets
import numpy as np
import torch
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModelForMaskedLM

# --- UI labels shared by the Gradio radios and the dispatch tables ---
COCONDENSER_CHOICE = "SPLADE-cocondenser-distil (weighting and expansion)"
LEXICAL_CHOICE = "SPLADE-v3-Lexical (weighting)"
DOC_CHOICE = "SPLADE-v3-Doc (binary)"
QUERY_MODEL_CHOICES = [COCONDENSER_CHOICE, LEXICAL_CHOICE, DOC_CHOICE]
INDEX_MODEL_CHOICES = ["SPLADE-cocondenser-distil", "SPLADE-v3-Lexical", "SPLADE-v3-Doc"]

# Decoded token strings that are never shown in the explorer output.
_SPECIAL_TOKEN_STRINGS = {"[CLS]", "[SEP]", "[PAD]", "[UNK]"}
# Maximum number of terms listed for the binary (SPLADE-v3-Doc) explorer view.
_MAX_BINARY_TERMS_SHOWN = 50

# --- Model Loading ---
# Each model is loaded independently so one failed download does not block the
# others. A failed load leaves the corresponding globals as None, and every
# consumer below checks for None before use.
tokenizer_splade = None
model_splade = None
tokenizer_splade_lexical = None
model_splade_lexical = None
tokenizer_splade_doc = None
model_splade_doc = None

# Load SPLADE cocondenser self-distil model (weighting and expansion)
try:
    tokenizer_splade = AutoTokenizer.from_pretrained("naver/splade-cocondenser-selfdistil")
    model_splade = AutoModelForMaskedLM.from_pretrained("naver/splade-cocondenser-selfdistil")
    model_splade.eval()
    print("SPLADE-cocondenser-distil model loaded successfully!")
except Exception as e:
    print(f"Error loading SPLADE-cocondenser-distil model: {e}")
    print("Please ensure you have accepted any user access agreements on the Hugging Face Hub page for 'naver/splade-cocondenser-selfdistil'.")

# Load SPLADE v3 Lexical model
try:
    splade_lexical_model_name = "naver/splade-v3-lexical"
    tokenizer_splade_lexical = AutoTokenizer.from_pretrained(splade_lexical_model_name)
    model_splade_lexical = AutoModelForMaskedLM.from_pretrained(splade_lexical_model_name)
    model_splade_lexical.eval()
    print(f"SPLADE-v3-Lexical model '{splade_lexical_model_name}' loaded successfully!")
except Exception as e:
    print(f"Error loading SPLADE-v3-Lexical model: {e}")
    print(f"Please ensure '{splade_lexical_model_name}' is accessible (check Hugging Face Hub for potential agreements).")

# Load SPLADE v3 Doc model
try:
    splade_doc_model_name = "naver/splade-v3-doc"
    tokenizer_splade_doc = AutoTokenizer.from_pretrained(splade_doc_model_name)
    model_splade_doc = AutoModelForMaskedLM.from_pretrained(splade_doc_model_name)
    model_splade_doc.eval()
    print(f"SPLADE-v3-Doc model '{splade_doc_model_name}' loaded successfully!")
except Exception as e:
    print(f"Error loading SPLADE-v3-Doc model: {e}")
    print(f"Please ensure '{splade_doc_model_name}' is accessible (check Hugging Face Hub for potential agreements).")

# --- Global Variables for Document Index and Qrels ---
document_representations = {}  # {doc_id: 1-D torch.Tensor of length vocab_size (dense storage, mostly zeros)}
document_texts = {}  # {doc_id: doc_text}
queries_texts = {}  # {query_id: query_text}
qrels_data = {}  # {query_id: [{"doc_id": str, "relevance": int}, ...]}
initial_doc_model_for_indexing = "SPLADE-cocondenser-distil"  # Fixed for initial demo index


# --- Load Cranfield Corpus, Queries, and Qrels using ir_datasets ---
def load_cranfield_corpus_ir_datasets():
    """Fill document_texts, queries_texts and qrels_data from ir_datasets' Cranfield.

    Errors are printed to the console rather than raised, so the app can still
    start (the Explorer tab does not need the corpus).
    """
    print("Loading Cranfield corpus, queries, and qrels using ir_datasets...")
    try:
        dataset = ir_datasets.load("cranfield")

        for doc in tqdm(dataset.docs_iter(), desc="Loading Cranfield documents"):
            document_texts[doc.doc_id] = doc.text.strip()
        print(f"Loaded {len(document_texts)} documents from Cranfield corpus.")

        for query in tqdm(dataset.queries_iter(), desc="Loading Cranfield queries"):
            queries_texts[query.query_id] = query.text.strip()
        print(f"Loaded {len(queries_texts)} queries from Cranfield corpus.")

        for qrel in tqdm(dataset.qrels_iter(), desc="Loading Cranfield qrels"):
            qrels_data.setdefault(qrel.query_id, []).append(
                {"doc_id": qrel.doc_id, "relevance": qrel.relevance}
            )
        print(f"Loaded qrels for {len(qrels_data)} queries.")
    except Exception as e:
        print(f"Error loading Cranfield corpus with ir_datasets: {e}")
        print("Please ensure 'ir_datasets' is installed and your internet connection is stable.")


# --- Helper function for lexical mask (handles batches) ---
def create_lexical_bow_mask(input_ids_batch, vocab_size, tokenizer):
    """Build a binary bag-of-words mask for every sequence in a batch.

    Args:
        input_ids_batch: torch.Tensor of shape (batch_size, sequence_length).
        vocab_size: size of the tokenizer vocabulary (width of the mask).
        tokenizer: tokenizer whose pad/cls/sep/mask/unk ids are excluded.

    Returns:
        Float torch.Tensor of shape (batch_size, vocab_size) on the same device
        as ``input_ids_batch``: 1 at every vocabulary id occurring in that row
        (special tokens excluded), 0 elsewhere.
    """
    # Built once per call rather than once per token. Ids the tokenizer does
    # not define are None and can never equal an int token id, so drop them.
    special_ids = {
        token_id
        for token_id in (
            tokenizer.pad_token_id,
            tokenizer.cls_token_id,
            tokenizer.sep_token_id,
            tokenizer.mask_token_id,
            tokenizer.unk_token_id,
        )
        if token_id is not None
    }
    batch_size = input_ids_batch.shape[0]
    bow_masks = torch.zeros(batch_size, vocab_size, device=input_ids_batch.device)
    for row, token_ids in enumerate(input_ids_batch.tolist()):
        present = {token_id for token_id in token_ids if token_id not in special_ids}
        if present:
            bow_masks[row, list(present)] = 1
    return bow_masks


# --- Shared private helpers ---
def _tokenize_for_model(texts, tokenizer, model):
    """Tokenize text(s) with padding/truncation and move tensors to the model's device."""
    inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
    return {k: v.to(model.device) for k, v in inputs.items()}


def _splade_max_pool(logits, attention_mask):
    """Apply SPLADE's log(1 + ReLU) activation and max-pool over the sequence.

    Reduces (batch, seq_len, vocab) logits to (batch, vocab). Padding
    positions are zeroed by the attention mask before pooling.
    """
    activations = torch.log(1 + torch.relu(logits)) * attention_mask.unsqueeze(-1)
    return torch.max(activations, dim=1).values


def _nonzero_entries(vector):
    """Return (indices, values) Python lists for the non-zero entries of a 1-D tensor.

    Uses ``as_tuple=True`` so a single non-zero entry (even at index 0) is
    never lost to ``squeeze()``/truthiness quirks.
    """
    indices = torch.nonzero(vector, as_tuple=True)[0]
    return indices.cpu().tolist(), vector[indices].cpu().tolist()


def _decode_meaningful_terms(tokenizer, indices, values):
    """Map each decoded token string to its weight, skipping special and blank tokens."""
    terms = {}
    for token_id, weight in zip(indices, values):
        decoded_token = tokenizer.decode([token_id])
        if decoded_token not in _SPECIAL_TOKEN_STRINGS and decoded_token.strip():
            terms[decoded_token] = weight
    return terms


def _format_weighted_report(title, vector, tokenizer):
    """Render a weighted sparse vector as Markdown, terms sorted by descending weight."""
    indices, values = _nonzero_entries(vector)
    terms = _decode_meaningful_terms(tokenizer, indices, values)
    ranked = sorted(terms.items(), key=lambda item: item[1], reverse=True)

    parts = [title]
    if not ranked:
        parts.append("No significant terms found for this input.\n")
    else:
        parts.extend(f"- **{term}**: {weight:.4f}\n" for term, weight in ranked)
    parts.append("\n--- Raw SPLADE Vector Info ---\n")
    parts.append(f"Total non-zero terms in vector: {len(indices)}\n")
    parts.append(f"Sparsity: {1 - (len(indices) / tokenizer.vocab_size):.2%}\n")
    return "".join(parts)


def _snippet(text, limit):
    """Return ``text`` cut to ``limit`` characters, with '...' appended only when cut."""
    return text[:limit] + "..." if len(text) > limit else text


# --- Internal Core Representation Functions (batched) ---
def get_splade_cocondenser_representation_internal(texts, tokenizer, model):
    """Generate SPLADE representations for a batch of texts.

    Args:
        texts: list of strings.
        tokenizer: the tokenizer object.
        model: the SPLADE masked-LM model.

    Returns:
        torch.Tensor of shape (batch_size, vocab_size), or None if the
        tokenizer/model is missing or the output has no ``logits``.
    """
    if tokenizer is None or model is None:
        return None
    inputs = _tokenize_for_model(texts, tokenizer, model)
    with torch.no_grad():
        output = model(**inputs)
    if not hasattr(output, "logits"):
        print("Model output structure not as expected for SPLADE-cocondenser-distil. 'logits' not found.")
        return None
    return _splade_max_pool(output.logits, inputs["attention_mask"])


def get_splade_lexical_representation_internal(texts, tokenizer, model):
    """Generate SPLADE-Lexical representations for a batch of texts.

    Same as the cocondenser representation, but weights are kept only for
    tokens that actually occur in each input (no expansion terms).

    Returns:
        torch.Tensor of shape (batch_size, vocab_size), or None on failure.
    """
    if tokenizer is None or model is None:
        return None
    inputs = _tokenize_for_model(texts, tokenizer, model)
    with torch.no_grad():
        output = model(**inputs)
    if not hasattr(output, "logits"):
        print("Model output structure not as expected for SPLADE-v3-Lexical. 'logits' not found.")
        return None
    splade_vectors = _splade_max_pool(output.logits, inputs["attention_mask"])
    bow_masks = create_lexical_bow_mask(inputs["input_ids"], tokenizer.vocab_size, tokenizer)
    return splade_vectors * bow_masks


def get_splade_doc_representation_internal(texts, tokenizer, model):
    """Generate SPLADE-Doc (binary bag-of-words) representations for a batch of texts.

    No forward pass is run: the representation is the binary mask of the
    input's non-special tokens. ``model`` is only used to pick the device.

    Returns:
        torch.Tensor of shape (batch_size, vocab_size) with 0/1 entries, or
        None if the tokenizer/model is missing.
    """
    if tokenizer is None or model is None:
        return None
    inputs = _tokenize_for_model(texts, tokenizer, model)
    return create_lexical_bow_mask(inputs["input_ids"], tokenizer.vocab_size, tokenizer)


# --- Explorer Representation Functions (single text -> formatted Markdown) ---
def get_splade_cocondenser_representation(text):
    """Return a Markdown report of the SPLADE-cocondenser-distil vector for ``text``."""
    if tokenizer_splade is None or model_splade is None:
        return "SPLADE-cocondenser-distil model is not loaded. Please check the console for loading errors."
    vectors = get_splade_cocondenser_representation_internal([text], tokenizer_splade, model_splade)
    if vectors is None:
        return "Model output structure not as expected for SPLADE-cocondenser-distil. 'logits' not found."
    return _format_weighted_report(
        "SPLADE-cocondenser-distil Representation (Weighting and Expansion):\n",
        vectors[0],
        tokenizer_splade,
    )


def get_splade_lexical_representation(text):
    """Return a Markdown report of the SPLADE-v3-Lexical vector for ``text``."""
    if tokenizer_splade_lexical is None or model_splade_lexical is None:
        return "SPLADE-v3-Lexical model is not loaded. Please check the console for loading errors."
    vectors = get_splade_lexical_representation_internal([text], tokenizer_splade_lexical, model_splade_lexical)
    if vectors is None:
        return "Model output structure not as expected for SPLADE-v3-Lexical. 'logits' not found."
    return _format_weighted_report(
        "SPLADE-v3-Lexical Representation (Weighting):\n",
        vectors[0],
        tokenizer_splade_lexical,
    )


def get_splade_doc_representation(text):
    """Return a Markdown report of the binary SPLADE-v3-Doc vector for ``text``.

    Terms are listed alphabetically, capped at _MAX_BINARY_TERMS_SHOWN entries.
    """
    if tokenizer_splade_doc is None or model_splade_doc is None:
        return "SPLADE-v3-Doc model is not loaded. Please check the console for loading errors."
    vectors = get_splade_doc_representation_internal([text], tokenizer_splade_doc, model_splade_doc)
    indices, _ = _nonzero_entries(vectors[0])
    terms = _decode_meaningful_terms(tokenizer_splade_doc, indices, [1.0] * len(indices))
    sorted_terms = sorted(terms)  # Sort alphabetically for clarity

    parts = ["SPLADE-v3-Doc Representation (Binary):\n"]
    if not sorted_terms:
        parts.append("No significant terms found for this input.\n")
    else:
        parts.extend(f"- **{term}**\n" for term in sorted_terms[:_MAX_BINARY_TERMS_SHOWN])
        if len(sorted_terms) > _MAX_BINARY_TERMS_SHOWN:  # Limit display for very long lists
            parts.append(f"...and {len(sorted_terms) - _MAX_BINARY_TERMS_SHOWN} more terms.\n")
    parts.append("\n--- Raw Binary Sparse Vector Info ---\n")
    parts.append(f"Total activated terms: {len(indices)}\n")
    parts.append(f"Sparsity: {1 - (len(indices) / tokenizer_splade_doc.vocab_size):.2%}\n")
    return "".join(parts)


# --- Unified Prediction Function for the Explorer Tab ---
def predict_representation_explorer(model_choice, text):
    """Dispatch the Explorer tab's radio choice to the matching report function."""
    handlers = {
        COCONDENSER_CHOICE: get_splade_cocondenser_representation,
        LEXICAL_CHOICE: get_splade_lexical_representation,
        DOC_CHOICE: get_splade_doc_representation,
    }
    handler = handlers.get(model_choice)
    if handler is None:
        return "Please select a model."
    return handler(text)


# --- Document Indexing Function (batched) ---
def index_documents(doc_model_choice, batch_size=32):
    """Encode every loaded document with ``doc_model_choice`` into document_representations.

    Args:
        doc_model_choice: one of INDEX_MODEL_CHOICES.
        batch_size: number of documents encoded per forward pass; adjust for
            memory/performance.

    Returns:
        True if an index exists afterwards (or already existed), False if the
        model choice is invalid or its model is not loaded.
    """
    global document_representations
    if document_representations:
        print("Documents already indexed. Skipping re-indexing.")
        return True

    components = {
        "SPLADE-cocondenser-distil": (tokenizer_splade, model_splade, get_splade_cocondenser_representation_internal),
        "SPLADE-v3-Lexical": (tokenizer_splade_lexical, model_splade_lexical, get_splade_lexical_representation_internal),
        "SPLADE-v3-Doc": (tokenizer_splade_doc, model_splade_doc, get_splade_doc_representation_internal),
    }
    if doc_model_choice not in components:
        print(f"Invalid model choice for document indexing: {doc_model_choice}")
        return False
    tokenizer_to_use, model_to_use, representation_func = components[doc_model_choice]
    if tokenizer_to_use is None or model_to_use is None:
        print(f"{doc_model_choice} model not loaded for indexing.")
        return False

    print(f"Indexing documents using {doc_model_choice}...")
    doc_ids = list(document_texts)
    doc_texts = list(document_texts.values())

    # Build into a fresh dict and publish at the end, so an exception mid-way
    # cannot leave a partial index that the "already indexed" guard accepts.
    new_index = {}
    for start in tqdm(range(0, len(doc_ids), batch_size), desc="Indexing Documents in Batches"):
        batch_doc_ids = doc_ids[start:start + batch_size]
        batch_doc_texts = doc_texts[start:start + batch_size]
        vectors = representation_func(batch_doc_texts, tokenizer_to_use, model_to_use)
        if vectors is None:
            print(f"Warning: Failed to get representation for a batch starting with doc_id {batch_doc_ids[0]}")
            continue
        # vectors has shape (len(batch_doc_ids), vocab_size); store one row per document.
        for doc_id, vector in zip(batch_doc_ids, vectors.cpu()):
            new_index[doc_id] = vector

    document_representations = new_index
    print(f"Finished indexing {len(document_representations)} documents.")
    return True


# --- Retrieval Function (for Retrieval Tab) ---
def retrieve_documents(query_text, query_model_choice, indexed_doc_model_name, top_k=5):
    """Score every indexed document against ``query_text`` by dot product.

    Args:
        query_text: free-text query.
        query_model_choice: one of QUERY_MODEL_CHOICES; selects the query encoder.
        indexed_doc_model_name: label of the model used for the index (display only).
        top_k: number of highest-scoring documents to return.

    Returns:
        (markdown_report, [(doc_id, score), ...]) with at most ``top_k`` results;
        on failure the list is empty and the report explains why.
    """
    if not document_representations:
        return "Document index is not loaded or empty. Please ensure documents are indexed.", []

    query_components = {
        COCONDENSER_CHOICE: (tokenizer_splade, model_splade, get_splade_cocondenser_representation_internal),
        LEXICAL_CHOICE: (tokenizer_splade_lexical, model_splade_lexical, get_splade_lexical_representation_internal),
        DOC_CHOICE: (tokenizer_splade_doc, model_splade_doc, get_splade_doc_representation_internal),
    }
    if query_model_choice not in query_components:
        return "Invalid query model choice.", []
    query_tokenizer, query_model, representation_func = query_components[query_model_choice]

    query_vector = representation_func([query_text], query_tokenizer, query_model)
    if query_vector is None:
        return "Failed to get query representation. Check console for model loading errors.", []
    # Internal functions return a batch; take the single row for this query.
    query_vector = query_vector.squeeze(0).cpu()

    scores = {
        doc_id: torch.dot(query_vector, doc_vec).item()
        for doc_id, doc_vec in document_representations.items()
    }
    top_results = sorted(scores.items(), key=lambda item: item[1], reverse=True)[:top_k]

    parts = [
        f"Retrieval Results for Query: '{query_text}'\n",
        f"Using Query Model: **{query_model_choice}**\n",
        f"Documents Indexed with: **{indexed_doc_model_name}**\n\n",
    ]
    if not top_results:
        parts.append("No documents found or scored.\n")
    else:
        for rank, (doc_id, score) in enumerate(top_results, start=1):
            doc_text = document_texts.get(doc_id, "Document text not available.")
            parts.append(f"**{rank}. Document ID: {doc_id}** (Score: {score:.4f})\n")
            parts.append(f"> {doc_text[:300]}...\n\n")
    return "".join(parts), top_results


# --- Unified Prediction Function for Gradio (for Retrieval Tab) ---
def predict_retrieval_gradio(query_text, query_model_choice, selected_doc_model_display_only):
    """Gradio callback: retrieve the top 5 documents against the pre-built index.

    ``selected_doc_model_display_only`` is ignored; the index is always the
    one built with ``initial_doc_model_for_indexing``.
    """
    formatted_output, _ = retrieve_documents(
        query_text, query_model_choice, initial_doc_model_for_indexing, top_k=5
    )
    return formatted_output


# --- Specific retrieval examples ---
def get_specific_retrieval_examples():
    """Pick a random query with both a high- and a low-relevance judged document.

    Returns a Markdown string with the query snippet and one snippet of each
    kind of document, or an explanatory message when nothing is available.
    """
    if not queries_texts or not qrels_data or not document_texts:
        return "Queries, qrels, or documents not loaded. Please check initial loading."

    # NOTE(review): assumes larger qrel values mean "more relevant" in
    # ir_datasets' Cranfield qrels — confirm against the dataset documentation.
    high_qrel_threshold = 3
    low_qrel_threshold = 1

    eligible_query_ids = [
        qid
        for qid, qrels in qrels_data.items()
        if any(item["relevance"] >= high_qrel_threshold for item in qrels)
        and any(item["relevance"] <= low_qrel_threshold for item in qrels)
    ]
    if not eligible_query_ids:
        return "Could not find a query with both high and low relevance documents in the loaded qrels."

    random_query_id = random.choice(eligible_query_ids)
    full_query_text = queries_texts.get(random_query_id, "Query text not found.")
    query_snippet = _snippet(full_query_text, 300)

    qrels_for_query = qrels_data[random_query_id]
    high_qrel_docs = [item for item in qrels_for_query if item["relevance"] >= high_qrel_threshold]
    low_qrel_docs = [item for item in qrels_for_query if item["relevance"] <= low_qrel_threshold]
    selected_high_doc_id = random.choice(high_qrel_docs)["doc_id"] if high_qrel_docs else None
    selected_low_doc_id = random.choice(low_qrel_docs)["doc_id"] if low_qrel_docs else None

    parts = [
        "### Random Query Example\n\n",
        f"**Query ID:** {random_query_id}\n",
        f"**Query Snippet:** {query_snippet}\n\n",
    ]

    if selected_high_doc_id:
        doc_snippet = _snippet(document_texts.get(selected_high_doc_id, "Document text not available."), 500)
        parts.append(f"### Highly Relevant Document (Qrel >= {high_qrel_threshold})\n")
        parts.append(f"**Document ID:** {selected_high_doc_id}\n")
        parts.append(f"**Document Snippet:** {doc_snippet}\n\n")
    else:
        parts.append("No highly relevant document found for this query.\n\n")

    if selected_low_doc_id:
        doc_snippet = _snippet(document_texts.get(selected_low_doc_id, "Document text not available."), 500)
        parts.append(f"### Lowly Relevant Document (Qrel <= {low_qrel_threshold})\n")
        parts.append(f"**Document ID:** {selected_low_doc_id}\n")
        parts.append(f"**Document Snippet:** {doc_snippet}\n\n")
    else:
        parts.append("No lowly relevant document found for this query.\n\n")

    return "".join(parts)


# --- Initial Load and Indexing Calls ---
# This part runs once when the app starts.
load_cranfield_corpus_ir_datasets()

_initial_index_models = {
    "SPLADE-cocondenser-distil": model_splade,
    "SPLADE-v3-Lexical": model_splade_lexical,
    "SPLADE-v3-Doc": model_splade_doc,
}
if _initial_index_models.get(initial_doc_model_for_indexing) is not None:
    index_documents(initial_doc_model_for_indexing)
else:
    print(f"Skipping document indexing: Model '{initial_doc_model_for_indexing}' failed to load or is not a valid choice for indexing.")

# --- Gradio Interface Setup with Tabs ---
with gr.Blocks(title="SPLADE Demos") as demo:
    gr.Markdown("# 🌌 SPLADE Demos: Sparse Representation Explorer & Document Retrieval")
    gr.Markdown("Explore different SPLADE models and their sparse representation types, or perform document retrieval on a test collection.")

    with gr.Tabs():
        with gr.TabItem("Sparse Representation Explorer"):
            gr.Markdown("### Explore Raw SPLADE Representations for Any Text")
            gr.Interface(
                fn=predict_representation_explorer,
                inputs=[
                    gr.Radio(
                        QUERY_MODEL_CHOICES,
                        label="Choose Representation Model",
                        value=COCONDENSER_CHOICE,
                    ),
                    gr.Textbox(
                        lines=5,
                        label="Enter your query or document text here:",
                        placeholder="e.g., Why is Padua the nicest city in Italy?",
                    ),
                ],
                outputs=gr.Markdown(),
                allow_flagging="never",
                # live=True  # Setting live=True might be slow for complex models on every keystroke
            )

        with gr.TabItem("Document Retrieval Demo"):
            gr.Markdown("### Retrieve Documents from Cranfield Collection")
            gr.Interface(
                fn=predict_retrieval_gradio,
                inputs=[
                    gr.Textbox(
                        lines=3,
                        label="Enter your query text here:",
                        placeholder="e.g., Does high-dose vitamin C cure cancer?",
                    ),
                    gr.Radio(
                        QUERY_MODEL_CHOICES,
                        label="Choose Query Representation Model",
                        value=COCONDENSER_CHOICE,
                    ),
                    gr.Radio(
                        INDEX_MODEL_CHOICES,
                        label=f"Document Index Model (Pre-indexed with: {initial_doc_model_for_indexing})",
                        value=initial_doc_model_for_indexing,
                        interactive=False,  # This radio is fixed for simplicity
                    ),
                ],
                outputs=gr.Markdown(),
                allow_flagging="never",
                # live=True  # retrieval is too heavy for live
            )

            gr.Markdown("---")  # Separator
            gr.Markdown("### Get Specific Retrieval Examples")
            specific_example_output = gr.Markdown()
            specific_example_button = gr.Button("Get Random Query with High/Low Qrel Docs")
            specific_example_button.click(
                fn=get_specific_retrieval_examples,
                inputs=[],
                outputs=specific_example_output,
            )

demo.launch()