"""Gradio demo for the Rank1 reasoning reranker.

Presents a query/passage form, streams the model's reasoning while it runs,
and then shows the final relevance verdict and confidence score.
"""

import asyncio  # not referenced in this module
import logging
import sys  # not referenced in this module
import warnings
from collections.abc import AsyncIterator
from functools import partial
from threading import Thread  # not referenced in this module

# Registered before the third-party imports below so it also covers any such
# warning raised while they load, not only warnings emitted afterwards.
warnings.filterwarnings("ignore", category=UserWarning, message="Can't initialize NVML")

# NOTE(review): `spaces` is kept ahead of transformers/torch, matching the
# original ordering; HF ZeroGPU appears to require being imported before CUDA
# is initialised -- confirm against the ZeroGPU docs before reordering.
import spaces
from transformers import AsyncTextIteratorStreamer  # not referenced in this module

import gradio as gr
import torch
import numpy as np

from model import Rank1

# Hugging Face Hub id of the (AWQ-quantised) Rank1 checkpoint served by the demo.
MODEL_NAME = "orionweller/rank1-32b-awq"

logger = logging.getLogger(__name__)

print(f"NumPy version: {np.__version__}")
print(f"PyTorch version: {torch.__version__}")


@spaces.GPU
async def process_input(
    query: str, passage: str, stream: bool = True
) -> AsyncIterator[tuple[str, str, str]]:
    """Run the reranker on one query/passage pair and stream UI-ready outputs.

    Args:
        query: Search query entered in the UI.
        passage: Candidate passage to judge against ``query``.
        stream: Forwarded to ``Rank1.predict`` as ``streamer``; the cached
            examples table calls this with ``False``.

    Yields:
        ``(relevance, confidence, reasoning)`` string triples. For results
        whose ``is_relevant`` is ``None`` (intermediate streaming output) the
        first two fields are the placeholder ``"Processing..."``. The final
        triple holds ``"Relevant"``/``"Not Relevant"`` and the confidence
        formatted as a percentage. If anything raises, a single error triple
        is yielded instead so the UI displays the message rather than failing.
    """
    try:
        # NOTE(review): the model is constructed on every call, as before.
        # Hoisting it to module scope would avoid reloading a 32B model per
        # request, but how that interacts with ZeroGPU's GPU-call isolation is
        # not visible here -- verify before changing.
        reranker = Rank1(model_name_or_path=MODEL_NAME)
        async for result in reranker.predict(query, passage, streamer=stream):
            if result["is_relevant"] is None:
                # Intermediate streaming result: only the reasoning is known yet.
                yield "Processing...", "Processing...", result["model_reasoning"]
                continue

            # Final result.
            relevance = "Relevant" if result["is_relevant"] else "Not Relevant"
            confidence = f"{result['confidence_score']:.2%}"
            yield relevance, confidence, result["model_reasoning"]
    except Exception as e:
        # Keep the full traceback in the server logs; the UI only gets the message.
        logger.exception("Rank1 prediction failed")
        yield f"Error: {e}", "N/A", "An error occurred during processing"


# Example inputs shown (and pre-computed) below the form.
examples = [
    [
        "What movies were directed by James Cameron?",
        "Avatar: The Way of Water is a 2022 American epic science fiction film directed by James Cameron.",
    ],
    [
        "What are the symptoms of COVID-19?",
        "Common symptoms of COVID-19 include fever, cough, fatigue, loss of taste or smell, and difficulty breathing.",
    ],
]

theme = gr.themes.Soft(
    primary_hue="indigo",
    font=["Inter", "ui-sans-serif", "system-ui", "sans-serif"],
    neutral_hue="slate",
    radius_size="lg",
)

with gr.Blocks(theme=theme, css=".red-text { color: red; }") as demo:
    gr.Markdown("# Rank1: Test Time Compute in Reranking")
    gr.HTML(
        "NOTE: for demo purposes this is a quantized model with a 1024 context length. "
        "HF spaces cannot use vLLM so this is significantly slower"
    )

    with gr.Row():
        with gr.Column():
            query_input = gr.Textbox(
                label="Query",
                placeholder="Enter your search query here",
                lines=2,
            )
            passage_input = gr.Textbox(
                label="Passage",
                placeholder="Enter the passage to check for relevance",
                lines=6,
            )
            submit_button = gr.Button("Check Relevance")

        with gr.Column():
            relevance_output = gr.Textbox(label="Relevance")
            confidence_output = gr.Textbox(label="Confidence")
            reasoning_output = gr.Textbox(
                label="Model Reasoning",
                lines=10,
                interactive=False,
            )

    # Cached examples run non-streaming so only the final triple is stored.
    gr.Examples(
        examples=examples,
        inputs=[query_input, passage_input],
        outputs=[relevance_output, confidence_output, reasoning_output],
        fn=partial(process_input, stream=False),
        cache_examples=True,
    )

    submit_button.click(
        fn=process_input,
        inputs=[query_input, passage_input],
        outputs=[relevance_output, confidence_output, reasoning_output],
        api_name="predict",
        queue=True,
    )

if __name__ == "__main__":
    demo.launch(share=True)