import math
import sys
import warnings
from collections.abc import Iterator
from functools import partial
from threading import Thread

# NOTE(review): `spaces` is deliberately kept ahead of transformers/torch,
# matching the file's existing import order; presumably ZeroGPU wants it
# imported first -- confirm before reordering.
import spaces
from transformers import TextIteratorStreamer
from huggingface_hub import snapshot_download
import gradio as gr
import torch
import numpy as np

from model import Rank1

print(f"NumPy version: {np.__version__}")
print(f"PyTorch version: {torch.__version__}")

# Suppress CUDA initialization warning
warnings.filterwarnings("ignore", category=UserWarning, message="Can't initialize NVML")

# Local path of the downloaded model snapshot. It is assigned in the
# ``__main__`` guard at the bottom of the file, before the demo is launched.
MODEL_PATH = None


def _relevance_score(true_logit: float, false_logit: float) -> float:
    """Return the two-way softmax probability of the 'true' logit.

    Mathematically equal to ``exp(t) / (exp(t) + exp(f))`` but written in
    terms of the logit difference so that ``math.exp`` is only ever called
    with a non-positive argument and cannot raise ``OverflowError``.
    """
    diff = true_logit - false_logit
    if diff >= 0:
        return 1.0 / (1.0 + math.exp(-diff))
    exp_diff = math.exp(diff)
    return exp_diff / (1.0 + exp_diff)


def _score_last_step(reranker, scores) -> float:
    """Score relevance from the final generation step of ``generate``.

    Reads the entries for ``reranker.true_token`` and ``reranker.false_token``
    from the last element of ``scores`` (first row, as a single prompt is
    generated at a time) and converts them to a probability.
    """
    final_scores = scores[-1][0]
    true_logit = final_scores[reranker.true_token].item()
    false_logit = final_scores[reranker.false_token].item()
    return _relevance_score(true_logit, false_logit)


def _format_verdict(score: float) -> tuple[str, str]:
    """Return the (relevance label, confidence percentage) strings for the UI."""
    label = "Relevant" if score > 0.5 else "Not Relevant"
    return label, f"{score:.2%}"


@spaces.GPU
def process_input(query: str, passage: str, stream: bool = True) -> Iterator[tuple[str, str, str]]:
    """Process input through the reranker and yield formatted outputs.

    Args:
        query: The search query.
        passage: The passage whose relevance to ``query`` is judged.
        stream: When true, yield partial reasoning text while the model is
            generating; otherwise yield a single final result.

    Yields:
        ``(relevance, confidence, reasoning)`` string triples. In streaming
        mode the first two fields are ``"Processing..."`` until generation
        finishes; the last yielded triple always carries the final label and
        the confidence formatted as a percentage.

    Raises:
        Exception: any error raised by ``generate`` in streaming mode is
            re-raised here after the worker thread has finished.
    """
    # NOTE(review): a new Rank1 is constructed on every request. If the
    # constructor loads weights this is expensive -- consider caching once it
    # is confirmed safe under @spaces.GPU.
    reranker = Rank1(model_name_or_path=MODEL_PATH)

    # NOTE(review): the prompt ends right after the passage line. If Rank1
    # expects a reasoning-start marker here, confirm and restore it.
    prompt = (
        "Determine if the following passage is relevant to the query. "
        "Answer only with 'true' or 'false'.\n"
        f"Query: {query}\n"
        f"Passage: {passage}\n"
    )

    reranker.model = reranker.model.to("cuda")
    inputs = reranker.tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=reranker.context_size,
    ).to("cuda")

    if stream:
        streamer = TextIteratorStreamer(
            reranker.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
        )
        current_text = ""
        generation_output = None
        generation_error = None

        def generate_with_output():
            nonlocal generation_output, generation_error
            try:
                generation_output = reranker.model.generate(
                    **inputs,
                    generation_config=reranker.generation_config,
                    stopping_criteria=reranker.stopping_criteria,
                    return_dict_in_generate=True,
                    output_scores=True,
                    streamer=streamer,
                )
            except Exception as exc:  # surfaced to the caller after join()
                generation_error = exc
                # Without an end signal the consumer loop below would block
                # forever waiting for the next chunk.
                streamer.end()

        thread = Thread(target=generate_with_output)
        thread.start()

        # Stream tokens as they're generated
        for new_text in streamer:
            current_text += new_text
            yield ("Processing...", "Processing...", current_text)

        thread.join()
        if generation_error is not None:
            raise generation_error

        # Add the stopping sequence and calculate final scores
        current_text += "\n" + reranker.stopping_criteria[0].matched_sequence
        score = _score_last_step(reranker, generation_output.scores)

        # Same formatting as the non-streaming branch so the UI and API see
        # identical final values regardless of mode.
        label, confidence = _format_verdict(score)
        yield (label, confidence, current_text)
    else:
        # Non-streaming mode
        with torch.no_grad():
            outputs = reranker.model.generate(
                **inputs,
                generation_config=reranker.generation_config,
                stopping_criteria=reranker.stopping_criteria,
                return_dict_in_generate=True,
                output_scores=True,
            )

        score = _score_last_step(reranker, outputs.scores)

        # only decode the generated text
        new_text = outputs.sequences[0][len(inputs.input_ids[0]):]
        decoded_input = reranker.tokenizer.decode(new_text)
        output_reasoning = "\n" + decoded_input.strip() + f"\n {'true' if score > 0.5 else 'false'}"

        label, confidence = _format_verdict(score)
        yield (label, confidence, output_reasoning)


# Example inputs
examples = [
    [
        "What movies were directed by James Cameron?",
        "Avatar: The Way of Water is a 2022 American epic science fiction film directed by James Cameron.",
    ],
    [
        "What are the symptoms of COVID-19?",
        "Common symptoms of COVID-19 include fever, cough, fatigue, loss of taste or smell, and difficulty breathing.",
    ],
]

theme = gr.themes.Soft(
    primary_hue="indigo",
    font=["Inter", "ui-sans-serif", "system-ui", "sans-serif"],
    neutral_hue="slate",
    radius_size="lg",
)

with gr.Blocks(theme=theme, css=".red-text { color: red; }") as demo:
    gr.Markdown("# Rank1: Test Time Compute in Reranking")
    gr.HTML('NOTE: for demo purposes this is a quantized model with a 1024 context length. HF spaces cannot use vLLM so this is significantly slower')

    with gr.Row():
        # Left column: user inputs.
        with gr.Column():
            query_input = gr.Textbox(
                label="Query",
                placeholder="Enter your search query here",
                lines=2,
            )
            passage_input = gr.Textbox(
                label="Passage",
                placeholder="Enter the passage to check for relevance",
                lines=6,
            )
            submit_button = gr.Button("Check Relevance")

        # Right column: model outputs.
        with gr.Column():
            relevance_output = gr.Textbox(label="Relevance")
            confidence_output = gr.Textbox(label="Confidence")
            reasoning_output = gr.Textbox(
                label="Model Reasoning",
                lines=10,
                interactive=False,
            )

    # Examples run in non-streaming mode so one final result is cached.
    gr.Examples(
        examples=examples,
        inputs=[query_input, passage_input],
        outputs=[relevance_output, confidence_output, reasoning_output],
        fn=partial(process_input, stream=False),
        cache_examples=True,
    )

    submit_button.click(
        fn=process_input,
        inputs=[query_input, passage_input],
        outputs=[relevance_output, confidence_output, reasoning_output],
        api_name="predict",
        queue=True,
    )

if __name__ == "__main__":
    # download model first, so we don't have to wait for it
    MODEL_PATH = snapshot_download(
        repo_id="orionweller/rank1-32b-awq",
    )
    print(f"Downloaded model to: {MODEL_PATH}")
    demo.launch(share=True)