# Hugging Face Space app.py by orionweller — commit 2413d91 ("Update app.py", verified), 3.64 kB.
import asyncio
import sys
import traceback
import warnings
from collections.abc import AsyncIterator
from functools import lru_cache, partial
from threading import Thread

# NOTE: `spaces` is kept ahead of torch/transformers, as in the original file;
# ZeroGPU's package is conventionally imported first — confirm before reordering.
import spaces
from transformers import AsyncTextIteratorStreamer
import gradio as gr
import torch
import numpy as np

from model import Rank1
# Report the installed library versions in the Space logs at startup.
print(
    f"NumPy version: {np.__version__}",
    f"PyTorch version: {torch.__version__}",
    sep="\n",
)
# Silence the "Can't initialize NVML" UserWarning, which is emitted when the
# NVIDIA management library cannot be initialized.
warnings.filterwarnings("ignore", message="Can't initialize NVML", category=UserWarning)
@lru_cache(maxsize=1)
def _load_reranker(model_name_or_path: str = "orionweller/rank1-32b-awq") -> Rank1:
    """Construct the Rank1 reranker once and reuse it for later requests.

    Previously a new ``Rank1`` was built on every call to ``process_input``
    (including once per cached example), which reloads the model each time.
    If construction raises, nothing is cached and the next call retries.

    NOTE(review): under ZeroGPU, ``@spaces.GPU`` calls may execute in a separate
    worker process, in which case this cache may not persist between requests —
    confirm; it is harmless either way.
    """
    return Rank1(model_name_or_path=model_name_or_path)


@spaces.GPU
async def process_input(
    query: str, passage: str, stream: bool = True
) -> AsyncIterator[tuple[str, str, str]]:
    """Judge whether *passage* is relevant to *query*, yielding UI-ready strings.

    Yields ``(relevance, confidence, reasoning)`` triples, one string per output
    textbox. While ``Rank1.predict`` reports ``is_relevant`` as ``None`` the first
    two fields are the placeholder ``"Processing..."`` and the third carries the
    partial model reasoning. The final triple holds ``"Relevant"`` or
    ``"Not Relevant"``, the confidence score formatted as a percentage with two
    decimals, and the model reasoning.

    Args:
        query: Search query entered by the user.
        passage: Candidate passage to judge against the query.
        stream: Forwarded to ``Rank1.predict`` as its ``streamer`` argument.

    Errors are not raised: the traceback is printed to stderr and an
    ``"Error: ..."`` triple is yielded so the UI always receives output.
    """
    # Nothing to rank: skip loading/running the model entirely.
    if not (query or "").strip() or not (passage or "").strip():
        yield "Error: query and passage must both be non-empty", "N/A", "Nothing to evaluate"
        return

    try:
        reranker = _load_reranker()
        async for result in reranker.predict(query, passage, streamer=stream):
            if result["is_relevant"] is None:
                # Intermediate streaming result: verdict not decided yet.
                yield "Processing...", "Processing...", result["model_reasoning"]
                continue
            # Final result.
            relevance = "Relevant" if result["is_relevant"] else "Not Relevant"
            confidence = f"{result['confidence_score']:.2%}"
            yield relevance, confidence, result["model_reasoning"]
    except Exception as e:
        # UI boundary: keep the full traceback in the logs, show a short message.
        traceback.print_exc(file=sys.stderr)
        yield f"Error: {e}", "N/A", "An error occurred during processing"
# Example rows for gr.Examples: each inner list is [query, passage], in the same
# order as the example inputs (query box, then passage box).
examples = [
    list(pair)
    for pair in (
        (
            "What movies were directed by James Cameron?",
            "Avatar: The Way of Water is a 2022 American epic science fiction film directed by James Cameron.",
        ),
        (
            "What are the symptoms of COVID-19?",
            "Common symptoms of COVID-19 include fever, cough, fatigue, loss of taste or smell, and difficulty breathing.",
        ),
    )
]
# Gradio "Soft" theme: indigo primary hue, slate neutrals, large corner radius,
# and Inter as the preferred font with generic sans-serif fallbacks.
theme = gr.themes.Soft(
    neutral_hue="slate",
    primary_hue="indigo",
    radius_size="lg",
    font=["Inter", "ui-sans-serif", "system-ui", "sans-serif"],
)
# Page layout: title and notice, a two-column row (inputs | outputs), cached
# examples, and the submit button wired to process_input.
with gr.Blocks(theme=theme, css=".red-text { color: red; }") as demo:
    gr.Markdown("# Rank1: Test Time Compute in Reranking")
    gr.HTML('NOTE: for demo purposes this is a <span style="color: red;">quantized</span> model with a <span style="color: red;">1024</span> context length. HF spaces cannot use vLLM so this is <span style="color: red;">significantly slower</span>')

    with gr.Row():
        # Left column: user inputs and the submit button.
        with gr.Column():
            query_input = gr.Textbox(
                label="Query", placeholder="Enter your search query here", lines=2
            )
            passage_input = gr.Textbox(
                label="Passage", placeholder="Enter the passage to check for relevance", lines=6
            )
            submit_button = gr.Button("Check Relevance")

        # Right column: one textbox per element of process_input's yielded triple.
        with gr.Column():
            relevance_output = gr.Textbox(label="Relevance")
            confidence_output = gr.Textbox(label="Confidence")
            reasoning_output = gr.Textbox(label="Model Reasoning", lines=10, interactive=False)

    # The examples and the button share the same input/output components;
    # each call receives its own fresh list.
    form_inputs = (query_input, passage_input)
    form_outputs = (relevance_output, confidence_output, reasoning_output)

    gr.Examples(
        examples=examples,
        inputs=list(form_inputs),
        outputs=list(form_outputs),
        fn=partial(process_input, stream=False),
        cache_examples=True,
    )
    submit_button.click(
        fn=process_input,
        inputs=list(form_inputs),
        outputs=list(form_outputs),
        api_name="predict",
        queue=True,
    )
if __name__ == "__main__":
    # Script entry point: start the Gradio server and request a public share link.
    launch_options = {"share": True}
    demo.launch(**launch_options)