import sys
import warnings
from collections.abc import Iterator
import spaces
from threading import Thread
from transformers import TextIteratorStreamer
from functools import partial
import gradio as gr
import torch
import numpy as np
from model import Rank1
import math

print(f"NumPy version: {np.__version__}")
print(f"PyTorch version: {torch.__version__}")

# Suppress CUDA initialization warning
warnings.filterwarnings("ignore", category=UserWarning, message="Can't initialize NVML")


def _relevance_score(true_logit: float, false_logit: float) -> float:
    """Return P("true") from a two-way softmax over the "true"/"false" logits.

    Mathematically identical to ``exp(t) / (exp(t) + exp(f))`` but computed as a
    logistic sigmoid of the logit difference, so large logits cannot overflow
    ``math.exp`` (which raises OverflowError for arguments above ~709), and a
    single ``-inf`` score (a token filtered out during generation) yields an
    exact 0 or 1 instead of a division by zero. If both scores are infinite
    with the same sign the difference is NaN and NaN is returned.
    """
    diff = true_logit - false_logit
    if diff >= 0:
        return 1.0 / (1.0 + math.exp(-diff))
    # diff < 0: exp(diff) is in (0, 1), so this branch cannot overflow either.
    z = math.exp(diff)
    return z / (1.0 + z)


def _final_relevance_score(reranker, step_scores) -> float:
    """Return P("true") using the last step of ``generate(output_scores=True)``.

    ``step_scores`` is the per-step score sequence from the generation output;
    only batch item 0 is read, and only the reranker's true/false token ids.
    """
    last_step = step_scores[-1][0]
    return _relevance_score(
        last_step[reranker.true_token].item(),
        last_step[reranker.false_token].item(),
    )


def _format_verdict(score: float) -> tuple[str, str]:
    """Map a probability to the (relevance label, confidence text) shown in the UI."""
    label = "Relevant" if score > 0.5 else "Not Relevant"
    return label, f"{score:.2%}"


@spaces.GPU
def process_input(query: str, passage: str, stream: bool = True) -> Iterator[tuple[str, str, str]]:
    """Process input through the reranker and yield formatted outputs.

    This is a generator: Gradio forwards each yielded tuple to the three
    output textboxes (relevance, confidence, reasoning).

    Args:
        query: The search query.
        passage: The passage whose relevance to ``query`` is judged.
        stream: If True, run generation in a background thread and yield the
            partial reasoning as tokens arrive, then a final result. If False,
            generate in a single call and yield only the final result.

    Yields:
        ``(relevance, confidence, reasoning)``. For the final result,
        ``relevance`` is "Relevant" or "Not Relevant" and ``confidence`` is
        P("true") formatted as a percentage; intermediate streaming updates
        use "Processing..." for both of those fields.

    Raises:
        Any exception raised by ``generate`` in the streaming worker thread is
        re-raised here once the thread has finished.
    """
    # NOTE(review): a new Rank1 is constructed on every request; if its state
    # (e.g. the stopping criteria's matched_sequence) is safe to share, loading
    # it once at module level would avoid the per-request cost — confirm first.
    reranker = Rank1(model_name_or_path="orionweller/rank1-32b-awq")
    # NOTE(review): the trailing empty literal looks like a placeholder for a
    # marker that may have been lost (e.g. a reasoning-start tag) — confirm
    # against the prompt format Rank1 was trained with.
    prompt = (
        "Determine if the following passage is relevant to the query. "
        "Answer only with 'true' or 'false'.\n"
        f"Query: {query}\n"
        f"Passage: {passage}\n"
        ""
    )

    reranker.model = reranker.model.to("cuda")
    inputs = reranker.tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=reranker.context_size,
    ).to("cuda")

    # Shared by both branches; output_scores/return_dict_in_generate are
    # required to read the final-step scores used for the relevance score.
    generate_kwargs = dict(
        generation_config=reranker.generation_config,
        stopping_criteria=reranker.stopping_criteria,
        return_dict_in_generate=True,
        output_scores=True,
    )

    if stream:
        streamer = TextIteratorStreamer(
            reranker.tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generation_output = None
        generation_error = None

        def generate_with_output():
            nonlocal generation_output, generation_error
            try:
                generation_output = reranker.model.generate(
                    **inputs, **generate_kwargs, streamer=streamer
                )
            except Exception as err:
                generation_error = err
                # Signal end-of-stream; otherwise the consumer loop below
                # would block forever waiting for tokens that never arrive.
                streamer.end()

        thread = Thread(target=generate_with_output)
        thread.start()

        # Stream tokens as they're generated
        current_text = ""
        for new_text in streamer:
            current_text += new_text
            yield ("Processing...", "Processing...", current_text)
        thread.join()

        if generation_error is not None:
            raise generation_error

        # Add the stopping sequence and calculate final scores
        current_text += "\n" + reranker.stopping_criteria[0].matched_sequence
        score = _final_relevance_score(reranker, generation_output.scores)
        relevance, confidence = _format_verdict(score)
        yield (relevance, confidence, current_text)
    else:
        # Non-streaming mode
        with torch.no_grad():
            outputs = reranker.model.generate(**inputs, **generate_kwargs)

        score = _final_relevance_score(reranker, outputs.scores)

        # Decode only the newly generated tokens, not the prompt.
        generated_ids = outputs.sequences[0][len(inputs.input_ids[0]):]
        decoded_output = reranker.tokenizer.decode(generated_ids)
        output_reasoning = "\n" + decoded_output.strip() + f"\n {'true' if score > 0.5 else 'false'}"

        relevance, confidence = _format_verdict(score)
        yield (relevance, confidence, output_reasoning)


# Example inputs
examples = [
    [
        "What movies were directed by James Cameron?",
        "Avatar: The Way of Water is a 2022 American epic science fiction film directed by James Cameron.",
    ],
    [
        "What are the symptoms of COVID-19?",
        "Common symptoms of COVID-19 include fever, cough, fatigue, loss of taste or smell, and difficulty breathing.",
    ],
]

theme = gr.themes.Soft(
    primary_hue="indigo",
    font=["Inter", "ui-sans-serif", "system-ui", "sans-serif"],
    neutral_hue="slate",
    radius_size="lg",
)

with gr.Blocks(theme=theme, css=".red-text { color: red; }") as demo:
    gr.Markdown("# Rank1: Test Time Compute in Reranking")
    gr.HTML('NOTE: for demo purposes this is a quantized model with a 1024 context length. HF spaces cannot use vLLM so this is significantly slower')

    with gr.Row():
        with gr.Column():
            query_input = gr.Textbox(
                label="Query",
                placeholder="Enter your search query here",
                lines=2,
            )
            passage_input = gr.Textbox(
                label="Passage",
                placeholder="Enter the passage to check for relevance",
                lines=6,
            )
            submit_button = gr.Button("Check Relevance")

        with gr.Column():
            relevance_output = gr.Textbox(label="Relevance")
            confidence_output = gr.Textbox(label="Confidence")
            reasoning_output = gr.Textbox(
                label="Model Reasoning",
                lines=10,
                interactive=False,
            )

    # Cached examples are computed through the non-streaming path.
    gr.Examples(
        examples=examples,
        inputs=[query_input, passage_input],
        outputs=[relevance_output, confidence_output, reasoning_output],
        fn=partial(process_input, stream=False),
        cache_examples=True,
    )

    submit_button.click(
        fn=process_input,
        inputs=[query_input, passage_input],
        outputs=[relevance_output, confidence_output, reasoning_output],
        api_name="predict",
        queue=True,
    )

if __name__ == "__main__":
    demo.launch(share=True)