orionweller's picture
Update app.py
25beaa7 verified
raw
history blame
5.21 kB
import math
import os
import shutil
import sys
import warnings
from collections.abc import Iterator
from functools import partial
from threading import Thread

# `spaces` is kept ahead of transformers/torch: the HF ZeroGPU docs ask for it to
# be imported before libraries that may initialise CUDA.
import spaces
from transformers import TextIteratorStreamer
from huggingface_hub import snapshot_download
import gradio as gr
import torch
import numpy as np

from model import Rank1
# Log library versions at startup to make environment issues easy to spot in the Space logs.
print("NumPy version:", np.__version__)
print("PyTorch version:", torch.__version__)

# Hide the NVML "can't initialize" UserWarning emitted on hosts without a usable GPU driver.
warnings.filterwarnings(
    action="ignore",
    message="Can't initialize NVML",
    category=UserWarning,
)

# Both globals are filled in by the `__main__` entry point before the UI launches:
# MODEL_PATH holds the downloaded checkpoint directory, reranker the Rank1 instance.
MODEL_PATH = reranker = None
def _true_probability(true_logit: float, false_logit: float) -> float:
    """Return the two-way softmax probability of the "true" logit.

    Mathematically ``exp(t) / (exp(t) + exp(f))``, but evaluated as a logistic of
    the logit gap so ``math.exp`` only ever receives a non-positive argument: it
    cannot overflow, and the denominator cannot underflow to zero.
    """
    gap = true_logit - false_logit
    if gap >= 0:
        return 1.0 / (1.0 + math.exp(-gap))
    odds = math.exp(gap)
    return odds / (1.0 + odds)


@spaces.GPU
def process_input(query: str, passage: str) -> Iterator[tuple[object, object, str]]:
    """Stream the reranker's reasoning for one (query, passage) pair, then its verdict.

    Yields 3-tuples matching the (Relevance, Confidence, Model Reasoning) outputs:
    while tokens stream, ``("Processing...", "Processing...", text_so_far)``; after
    generation finishes, one final ``(score > 0.5, score, full_text)``, where
    ``score`` is the softmax probability of the "true" token against the "false"
    token, taken from the scores of the last generation step.

    Raises:
        Exception: anything ``reranker.model.generate`` raised on the worker thread
            is re-raised here, instead of the stream hanging forever.
    """
    prompt = (
        "Determine if the following passage is relevant to the query. "
        "Answer only with 'true' or 'false'.\n"
        f"Query: {query}\n"
        f"Passage: {passage}\n"
        "<think>"
    )
    # Move the weights onto the GPU for this call.
    reranker.model = reranker.model.to("cuda")
    inputs = reranker.tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=reranker.context_size,
    ).to("cuda")
    streamer = TextIteratorStreamer(
        reranker.tokenizer,
        skip_prompt=True,
        skip_special_tokens=False,
    )

    # skip_prompt=True hides the prompt (which ends in "<think>") from the stream,
    # so seed the visible reasoning with that tag.
    current_text = "<think>"
    generation_output = None
    generation_error = None

    def generate_with_output():
        # Runs on a worker thread while the request thread consumes `streamer`.
        nonlocal generation_output, generation_error
        try:
            generation_output = reranker.model.generate(
                **inputs,
                generation_config=reranker.generation_config,
                stopping_criteria=reranker.stopping_criteria,
                return_dict_in_generate=True,
                output_scores=True,
                streamer=streamer,
            )
        except Exception as err:
            generation_error = err
            # generate() may not have closed the stream; end it so the consumer
            # loop below terminates instead of blocking indefinitely.
            streamer.end()

    thread = Thread(target=generate_with_output)
    thread.start()

    # Stream tokens as they're generated.
    for new_text in streamer:
        current_text += new_text
        yield "Processing...", "Processing...", current_text
    thread.join()
    if generation_error is not None:
        raise generation_error

    # If the closing tag did not appear in the streamed text, append the sequence
    # the first stopping criterion matched.
    if "</think>" not in current_text:
        current_text += "\n" + reranker.stopping_criteria[0].matched_sequence

    # Relevance = P("true") under a two-way softmax over the last step's logits.
    final_scores = generation_output.scores[-1][0]
    true_logit = final_scores[reranker.true_token].item()
    false_logit = final_scores[reranker.false_token].item()
    score = _true_probability(true_logit, false_logit)
    yield score > 0.5, score, current_text
# Example rows for gr.Examples: one query paired with a relevant passage and an
# unrelated one, so both verdicts can be demonstrated.
examples = [
    ["What movies were directed by James Cameron?", example_passage]
    for example_passage in (
        "Avatar: The Way of Water is a 2022 American epic science fiction film directed by James Cameron.",
        "Common symptoms of COVID-19 include fever, cough, fatigue, loss of taste or smell, and difficulty breathing.",
    )
]
# Visual theme for the Blocks app: the Soft preset with indigo accents, slate
# neutrals, large corner radius and an Inter-first font stack.
theme = gr.themes.Soft(
    radius_size="lg",
    neutral_hue="slate",
    primary_hue="indigo",
    font=["Inter", "ui-sans-serif", "system-ui", "sans-serif"],
)
# Gradio UI: query/passage inputs on the left; relevance, confidence and the
# streamed reasoning on the right. The example rows and the button both run
# process_input.
with gr.Blocks(theme=theme, css=".red-text { color: red; }") as demo:
    gr.Markdown("# Rank1: Test Time Compute in Reranking")
    gr.HTML('NOTE: for demo purposes this is a <span style="color: red;">quantized</span> model limited to a <span style="color: red;">1024</span> context length. HF spaces cannot use vLLM so this is <span style="color: red;">significantly slower</span>')
    # Correctly encoded emoji, and the URL rendered as a clickable link.
    gr.HTML(
        "📄 Paper Link: "
        '<a href="https://arxiv.org/abs/2502.18418" target="_blank" rel="noopener noreferrer">'
        "https://arxiv.org/abs/2502.18418</a>"
    )
    with gr.Row():
        with gr.Column():
            query_input = gr.Textbox(
                label="Query",
                placeholder="Enter your search query here",
                lines=2,
            )
            passage_input = gr.Textbox(
                label="Passage",
                placeholder="Enter the passage to check for relevance",
                lines=6,
            )
            submit_button = gr.Button("Check Relevance")
        with gr.Column():
            relevance_output = gr.Textbox(label="Relevance")
            confidence_output = gr.Textbox(label="Confidence")
            reasoning_output = gr.Textbox(
                label="Model Reasoning",
                lines=10,
                interactive=False,
            )
    # cache_examples=True asks Gradio to store the example outputs rather than
    # recomputing them on every click.
    gr.Examples(
        examples=examples,
        inputs=[query_input, passage_input],
        outputs=[relevance_output, confidence_output, reasoning_output],
        fn=process_input,
        cache_examples=True,
    )
    submit_button.click(
        fn=process_input,
        inputs=[query_input, passage_input],
        outputs=[relevance_output, confidence_output, reasoning_output],
        api_name="predict",
        queue=True,
    )
if __name__ == "__main__":
    # Fetch the checkpoint and build the reranker before the UI starts serving,
    # so process_input never runs against the unset global.
    MODEL_PATH = snapshot_download(repo_id="orionweller/rank1-7b-awq")
    print("Downloaded model to:", MODEL_PATH)
    reranker = Rank1(model_name_or_path=MODEL_PATH)
    demo.launch(share=False)