File size: 3,638 Bytes
00134aa
 
2413d91
00134aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b5f72c9
00134aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import asyncio
import sys
import warnings
from collections.abc import AsyncGenerator
from functools import partial
from threading import Thread

import gradio as gr
import numpy as np
import spaces
import torch
from transformers import AsyncTextIteratorStreamer

from model import Rank1

# Log library versions at startup to aid debugging of environment issues.
print(f"NumPy version: {np.__version__}")
print(f"PyTorch version: {torch.__version__}")

# Suppress CUDA initialization warning
# (NVML is often unavailable in containerized Spaces environments).
warnings.filterwarnings("ignore", category=UserWarning, message="Can't initialize NVML")
    
@spaces.GPU
async def process_input(
    query: str, passage: str, stream: bool = True
) -> AsyncGenerator[tuple[str, str, str], None]:
    """Run the rank1 reranker on a (query, passage) pair, yielding UI updates.

    This is an async generator (note the corrected return annotation: the
    original declared ``tuple[str, str, str]``, but the function yields).

    Args:
        query: The search query to judge relevance for.
        passage: The candidate passage to score against the query.
        stream: Forwarded to the model as ``streamer``; when True,
            intermediate reasoning tokens are yielded as they arrive.

    Yields:
        ``(relevance, confidence, reasoning)`` string triples. While the
        model is still reasoning (``is_relevant`` is None) the first two
        slots are placeholder "Processing..." strings; the final yield
        carries the verdict, a percent-formatted confidence, and the full
        reasoning trace.
    """
    try:
        # Instantiated per call: @spaces.GPU allocates the GPU only for the
        # duration of this function, so the model cannot be held globally.
        reranker = Rank1(model_name_or_path="orionweller/rank1-32b-awq")
        async for result in reranker.predict(query, passage, streamer=stream):
            if result["is_relevant"] is None:
                # Intermediate streaming result: only reasoning is ready.
                yield "Processing...", "Processing...", result["model_reasoning"]
            else:
                # Final result.
                relevance = "Relevant" if result["is_relevant"] else "Not Relevant"
                confidence = f"{result['confidence_score']:.2%}"
                reasoning = result["model_reasoning"]
                yield relevance, confidence, reasoning
    except Exception as e:
        # Top-level UI boundary: surface the error in the outputs rather
        # than crashing the Gradio event handler.
        yield f"Error: {str(e)}", "N/A", "An error occurred during processing"

# Example inputs: [query, passage] pairs shown in the gr.Examples widget.
examples = [
    [
        "What movies were directed by James Cameron?",
        "Avatar: The Way of Water is a 2022 American epic science fiction film directed by James Cameron.",
    ],
    [
        "What are the symptoms of COVID-19?",
        "Common symptoms of COVID-19 include fever, cough, fatigue, loss of taste or smell, and difficulty breathing.",
    ]
]

# Visual theme for the Gradio app (indigo accent, Inter font, rounded corners).
theme = gr.themes.Soft(
    primary_hue="indigo",
    font=["Inter", "ui-sans-serif", "system-ui", "sans-serif"],
    neutral_hue="slate",
    radius_size="lg",
)

# UI layout: two-column Blocks app — inputs (query/passage) on the left,
# reranker outputs (relevance/confidence/reasoning) on the right.
with gr.Blocks(theme=theme, css=".red-text { color: red; }") as demo:
    gr.Markdown("# Rank1: Test Time Compute in Reranking")
    gr.HTML('NOTE: for demo purposes this is a <span style="color: red;">quantized</span> model with a <span style="color: red;">1024</span> context length. HF spaces cannot use vLLM so this is <span style="color: red;">significantly slower</span>')
    
    with gr.Row():
        with gr.Column():
            # Left column: user inputs.
            query_input = gr.Textbox(
                label="Query",
                placeholder="Enter your search query here",
                lines=2
            )
            passage_input = gr.Textbox(
                label="Passage",
                placeholder="Enter the passage to check for relevance",
                lines=6
            )
            submit_button = gr.Button("Check Relevance")
        
        with gr.Column():
            # Right column: streamed model outputs.
            relevance_output = gr.Textbox(label="Relevance")
            confidence_output = gr.Textbox(label="Confidence")
            reasoning_output = gr.Textbox(
                label="Model Reasoning",
                lines=10,
                interactive=False
            )
    
    # Pre-computed examples; stream=False so cached runs return only the
    # final result.
    # NOTE(review): fn is an async generator even with stream=False —
    # confirm cache_examples=True handles generator fns in this gradio
    # version.
    gr.Examples(
        examples=examples,
        inputs=[query_input, passage_input],
        outputs=[relevance_output, confidence_output, reasoning_output],
        fn=partial(process_input, stream=False),
        cache_examples=True,
    )
    
    # Wire the button to the streaming handler; queue=True is required for
    # generator-based (streaming) event handlers.
    submit_button.click(
        fn=process_input,
        inputs=[query_input, passage_input],
        outputs=[relevance_output, confidence_output, reasoning_output],
        api_name="predict",
        queue=True
    )

# Script entry point: launch the app with a public share link.
if __name__ == "__main__":
    demo.launch(share=True)