import math
import warnings
from threading import Thread
from typing import Iterator

import gradio as gr
import numpy as np
import spaces
import torch
from huggingface_hub import snapshot_download
from transformers import TextIteratorStreamer

from model import Rank1

print(f"NumPy version: {np.__version__}")
print(f"PyTorch version: {torch.__version__}")

# Suppress CUDA initialization warning
warnings.filterwarnings("ignore", category=UserWarning, message="Can't initialize NVML")

# Both are populated in __main__ before the demo launches.
MODEL_PATH = None
reranker = None


@spaces.GPU
def process_input(query: str, passage: str) -> Iterator[tuple[str, str, str]]:
    """Stream the reranker's reasoning, then yield the final verdict and confidence."""
    # The prompt ends with an opening <think> tag so the model continues with
    # its reasoning before emitting "true" or "false".
    prompt = (
        "Determine if the following passage is relevant to the query. "
        "Answer only with 'true' or 'false'.\n"
        f"Query: {query}\n"
        f"Passage: {passage}\n"
        "<think>"
    )

    reranker.model = reranker.model.to("cuda")
    inputs = reranker.tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=reranker.context_size,
    ).to("cuda")

    streamer = TextIteratorStreamer(
        reranker.tokenizer,
        skip_prompt=True,
        skip_special_tokens=False,
    )

    current_text = ""
    generation_output = None

    def generate_with_output():
        nonlocal generation_output
        generation_output = reranker.model.generate(
            **inputs,
            generation_config=reranker.generation_config,
            stopping_criteria=reranker.stopping_criteria,
            return_dict_in_generate=True,
            output_scores=True,
            streamer=streamer,
        )

    thread = Thread(target=generate_with_output)
    thread.start()

    # Stream tokens to the UI as they're generated.
    for new_text in streamer:
        current_text += new_text
        yield ("Processing...", "Processing...", current_text)

    thread.join()

    # If generation stopped before the reasoning was closed, append the
    # stopping sequence that the stopping criteria matched on.
    if "</think>" not in current_text:
        current_text += "\n" + reranker.stopping_criteria[0].matched_sequence

    with torch.no_grad():
        final_scores = generation_output.scores[-1][0]
        true_logit = final_scores[reranker.true_token].item()
        false_logit = final_scores[reranker.false_token].item()
        # Softmax over the two logits, P(true) = e^t / (e^t + e^f), computed
        # as a numerically stable sigmoid of the logit difference.
        score = 1.0 / (1.0 + math.exp(false_logit - true_logit))

    yield (str(score > 0.5), f"{score:.4f}", current_text)


# Example inputs
examples = [
    [
        "What movies were directed by James Cameron?",
        "Avatar: The Way of Water is a 2022 American epic science fiction film "
        "directed by James Cameron.",
    ],
    [
        "What movies were directed by James Cameron?",
        "Common symptoms of COVID-19 include fever, cough, fatigue, loss of "
        "taste or smell, and difficulty breathing.",
    ],
]

theme = gr.themes.Soft(
    primary_hue="indigo",
    font=["Inter", "ui-sans-serif", "system-ui", "sans-serif"],
    neutral_hue="slate",
    radius_size="lg",
)

with gr.Blocks(theme=theme, css=".red-text { color: red; }") as demo:
    gr.Markdown("# Rank1: Test-Time Compute in Reranking")
    gr.HTML(
        "NOTE: for demo purposes this is a quantized model limited to a "
        "1024-token context length. HF Spaces cannot use vLLM, so this is "
        "significantly slower."
    )
    gr.HTML('📄 Paper link: <a href="https://arxiv.org/abs/2502.18418">https://arxiv.org/abs/2502.18418</a>')

    with gr.Row():
        with gr.Column():
            query_input = gr.Textbox(
                label="Query",
                placeholder="Enter your search query here",
                lines=2,
            )
            passage_input = gr.Textbox(
                label="Passage",
                placeholder="Enter the passage to check for relevance",
                lines=6,
            )
            submit_button = gr.Button("Check Relevance")

        with gr.Column():
            relevance_output = gr.Textbox(label="Relevance")
            confidence_output = gr.Textbox(label="Confidence")
            reasoning_output = gr.Textbox(
                label="Model Reasoning",
                lines=10,
                interactive=False,
            )

    gr.Examples(
        examples=examples,
        inputs=[query_input, passage_input],
        outputs=[relevance_output, confidence_output, reasoning_output],
        fn=process_input,
        cache_examples=True,
    )

    submit_button.click(
        fn=process_input,
        inputs=[query_input, passage_input],
        outputs=[relevance_output, confidence_output, reasoning_output],
        api_name="predict",
        queue=True,
    )

if __name__ == "__main__":
    # Download the model up front so the first request doesn't have to wait.
    MODEL_PATH = snapshot_download(repo_id="orionweller/rank1-7b-awq")
    print(f"Downloaded model to: {MODEL_PATH}")

    reranker = Rank1(model_name_or_path=MODEL_PATH)

    demo.launch(share=False)
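# Usage sketch (illustrative, not executed by this app): with the demo running,
# the "predict" endpoint registered above can be called from Python via
# gradio_client. The local URL below is a placeholder for wherever the Space is
# actually hosted.
#
#   from gradio_client import Client
#
#   client = Client("http://127.0.0.1:7860/")
#   relevance, confidence, reasoning = client.predict(
#       "What movies were directed by James Cameron?",
#       "Avatar: The Way of Water is a 2022 American epic science fiction "
#       "film directed by James Cameron.",
#       api_name="/predict",
#   )
#   print(relevance, confidence, reasoning)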