import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Load model and tokenizer
model_name = "castorini/monot5-small-msmarco-10k"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
model.eval()

# Token ids for the "true"/"false" labels MonoT5 is trained to generate
token_true_id = tokenizer.get_vocab()["▁true"]
token_false_id = tokenizer.get_vocab()["▁false"]

# Define reranking function
def rerank(query, documents):
    documents = [doc for doc in documents.split("\n") if doc.strip()]  # Split documents by newlines
    reranked_results = []
    for doc in documents:
        # Combine query and document into MonoT5's expected input format
        input_text = f"Query: {query} Document: {doc} Relevant:"
        inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
        # Generate a single token and keep its logits so relevance can be scored
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=1,
                output_scores=True,
                return_dict_in_generate=True,
            )
        logits = outputs.scores[0][0]
        # Relevance score = probability of "true" vs. "false" for the generated token
        score = torch.softmax(logits[[token_false_id, token_true_id]], dim=0)[1].item()
        reranked_results.append((doc, score))
    # Sort by relevance score, highest first
    reranked_results.sort(key=lambda x: x[1], reverse=True)
    return "\n".join(f"{doc} (Relevance: {score:.4f})" for doc, score in reranked_results)

# Create Gradio interface
interface = gr.Interface(
    fn=rerank,
    inputs=[
        gr.Textbox(label="Query", placeholder="Enter your query"),
        gr.Textbox(label="Documents (one per line)", lines=5, placeholder="Enter documents to rank"),
    ],
    outputs=gr.Textbox(label="Reranked Documents"),
    title="MonoT5 Reranking",
    description="Provide a query and a list of documents to rerank them using MonoT5.",
)

# Launch the app
if __name__ == "__main__":
    interface.launch()