import gradio as gr
from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModel
import torch

# --- Model Loading ---
# Each model/tokenizer pair stays None when loading fails; the representation
# functions check for None and return an explanatory message instead of crashing.
tokenizer_splade = None
model_splade = None
tokenizer_splade_lexical = None
model_splade_lexical = None

# Load the SPLADE cocondenser-selfdistil model (the "original" option in the UI).
# NOTE(review): this checkpoint is generally described as SPLADE++ rather than
# SPLADE v3; the console message below keeps its original wording — confirm.
try:
    tokenizer_splade = AutoTokenizer.from_pretrained("naver/splade-cocondenser-selfdistil")
    model_splade = AutoModelForMaskedLM.from_pretrained("naver/splade-cocondenser-selfdistil")
    model_splade.eval()  # Set to evaluation mode for inference
    print("SPLADE v3 (cocondenser) model loaded successfully!")
except Exception as e:
    print(f"Error loading SPLADE (cocondenser) model: {e}")
    print("Please ensure you have accepted any user access agreements on the Hugging Face Hub page for 'naver/splade-cocondenser-selfdistil'.")

# Load SPLADE v3 Lexical model.
# The name is bound before the try so the except branch can always reference it.
splade_lexical_model_name = "naver/splade-v3-lexical"
try:
    tokenizer_splade_lexical = AutoTokenizer.from_pretrained(splade_lexical_model_name)
    model_splade_lexical = AutoModelForMaskedLM.from_pretrained(splade_lexical_model_name)
    model_splade_lexical.eval()  # Set to evaluation mode for inference
    print(f"SPLADE v3 Lexical model '{splade_lexical_model_name}' loaded successfully!")
except Exception as e:
    print(f"Error loading SPLADE v3 Lexical model: {e}")
    print(f"Please ensure '{splade_lexical_model_name}' is accessible (check Hugging Face Hub for potential agreements).")


# --- Core Representation Functions ---

# Token strings that are never shown as terms. These are checked in addition to
# the tokenizer's own registered special ids, so they are still filtered when a
# tokenizer does not register them as special tokens.
_SKIPPED_TOKEN_STRINGS = frozenset({"[CLS]", "[SEP]", "[PAD]", "[UNK]"})


def _compute_splade_vector(tokenizer, model, text):
    """Encode *text* and return its 1-D SPLADE weight vector.

    Returns None when the model output has no ``logits`` attribute.
    """
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    with torch.no_grad():
        output = model(**inputs)
    if not hasattr(output, "logits"):
        return None
    # SPLADE activation: log(1 + ReLU(logits)), zeroed at padding positions via
    # the attention mask, then max-pooled over the sequence dimension (dim=1).
    activations = torch.log(1 + torch.relu(output.logits)) * inputs["attention_mask"].unsqueeze(-1)
    # A single input text gives a batch of one; drop only that batch dimension.
    return torch.max(activations, dim=1)[0].squeeze(0)


def _splade_representation(tokenizer, model, display_name, text):
    """Build the Markdown report of all non-zero SPLADE terms for *text*.

    Args:
        tokenizer: Tokenizer paired with *model*, or None if it failed to load.
        model: Masked-LM model, or None if it failed to load.
        display_name: Human-readable model name used in every message.
        text: Query/document text; None is treated as the empty string.

    Returns:
        A Markdown string: the report, or an error message if the model is
        unavailable or its output has no logits.
    """
    if tokenizer is None or model is None:
        return f"{display_name} model is not loaded. Please check the console for loading errors."

    splade_vector = _compute_splade_vector(tokenizer, model, "" if text is None else text)
    if splade_vector is None:
        return f"Model output structure not as expected for {display_name}. 'logits' not found."

    # as_tuple=True yields a flat index tensor even for 0 or 1 non-zero entries,
    # so no scalar special-casing is needed.
    nonzero_ids = torch.nonzero(splade_vector, as_tuple=True)[0]
    indices = nonzero_ids.cpu().tolist()
    values = splade_vector[nonzero_ids].cpu().tolist()

    special_ids = set(tokenizer.all_special_ids)
    meaningful_tokens = {}
    for token_id, weight in zip(indices, values):
        if token_id in special_ids:
            continue
        decoded_token = tokenizer.decode([token_id])
        if decoded_token in _SKIPPED_TOKEN_STRINGS or not decoded_token.strip():
            continue
        # Distinct ids can decode to the same string; keep the strongest weight
        # instead of letting whichever id comes last overwrite it.
        if weight > meaningful_tokens.get(decoded_token, float("-inf")):
            meaningful_tokens[decoded_token] = weight

    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)

    parts = [f"{display_name} Representation (All Non-Zero Terms):\n"]
    if not sorted_representation:
        parts.append("No significant terms found for this input.\n")
    else:
        parts.extend(f"- **{term}**: {weight:.4f}\n" for term, weight in sorted_representation)

    parts.append("\n--- Raw SPLADE Vector Info ---\n")
    parts.append(f"Total non-zero terms in vector: {len(indices)}\n")
    # Sparsity is measured against the actual vector length (the model's output
    # dimension), which is what the non-zero count was taken over.
    parts.append(f"Sparsity: {1 - (len(indices) / splade_vector.numel()):.2%}\n")
    return "".join(parts)


def get_splade_representation(text):
    """Return the Markdown SPLADE (cocondenser) term report for *text*."""
    return _splade_representation(tokenizer_splade, model_splade, "SPLADE (cocondenser)", text)


def get_splade_lexical_representation(text):
    """Return the Markdown SPLADE v3 Lexical term report for *text*."""
    return _splade_representation(tokenizer_splade_lexical, model_splade_lexical, "SPLADE v3 Lexical", text)


# --- Unified Prediction Function for Gradio ---

# Maps the UI radio labels to the function that renders that model's report.
_REPRESENTATION_FUNCTIONS = {
    "SPLADE (cocondenser)": get_splade_representation,
    "SPLADE-v3-Lexical": get_splade_lexical_representation,
}


def predict_representation(model_choice, text):
    """Dispatch *text* to the representation function chosen in the UI.

    Returns "Please select a model." for an unknown or missing choice.
    """
    handler = _REPRESENTATION_FUNCTIONS.get(model_choice)
    if handler is None:
        return "Please select a model."
    return handler(text)
# --- Gradio Interface Setup ---

# Radio labels; these must match the choices handled by predict_representation.
_MODEL_OPTIONS = ["SPLADE (cocondenser)", "SPLADE-v3-Lexical"]

# Model selector; the first option is pre-selected.
_model_selector = gr.Radio(
    _MODEL_OPTIONS,
    label="Choose Representation Model",
    value=_MODEL_OPTIONS[0],
)

# Free-text input for the query or document to encode.
_text_input = gr.Textbox(
    lines=5,
    label="Enter your query or document text here:",
    placeholder="e.g., Why is Padua the nicest city in Italy?",
)

demo = gr.Interface(
    fn=predict_representation,
    inputs=[_model_selector, _text_input],
    outputs=gr.Markdown(),  # the representation report is Markdown-formatted
    title="🌌 Sparse and Binary Sparse Representation Generator",
    description="Enter any text to see its SPLADE sparse vector or SPLADE-v3-Lexical representation.",
    allow_flagging="never",
)

# Launch the Gradio app
demo.launch()