# Scraped file-view header: author oweller2, commit 00588f0 ("working without async"), 6.3 kB.
import sys
import warnings
import spaces
from threading import Thread
from transformers import TextIteratorStreamer
from functools import partial
import gradio as gr
import torch
import numpy as np
from model import Rank1
import math
# Print library versions at startup so the logs show which builds are loaded.
for _library, _version in (("NumPy", np.__version__), ("PyTorch", torch.__version__)):
    print(f"{_library} version: {_version}")
del _library, _version

# Suppress CUDA initialization warning: ignore any UserWarning whose message
# matches "Can't initialize NVML" (filterwarnings matches at the message start).
warnings.filterwarnings(
    "ignore",
    message="Can't initialize NVML",
    category=UserWarning,
)
def _relevance_probability(true_logit: float, false_logit: float) -> float:
    """Return P('true') from a two-way softmax over the 'true'/'false' logits.

    Equal to ``exp(t) / (exp(t) + exp(f))`` but evaluated as a logistic of the
    logit difference, so large logits cannot overflow ``math.exp`` and very
    negative (or ``-inf``) logits cannot underflow the denominator to zero.
    If both logits are equal -- including both ``-inf`` -- the result is 0.5.
    """
    if true_logit == false_logit:
        # Also covers both logits being -inf, whose difference would be NaN.
        return 0.5
    margin = false_logit - true_logit
    if margin >= 0:
        odds = math.exp(-margin)  # <= 1, cannot overflow
        return odds / (1.0 + odds)
    return 1.0 / (1.0 + math.exp(margin))  # exp of a negative number, cannot overflow


def _score_from_generation(reranker, generation_output) -> float:
    """Return P(relevant) from the scores of the final generated step."""
    # NOTE(review): scores[-1] is the distribution at the last *generated* step.
    # Whether that is the distribution over the answer token depends on how
    # Rank1's stopping criteria end generation -- confirm against model.Rank1.
    with torch.no_grad():
        final_scores = generation_output.scores[-1][0]  # last step, first batch row
        true_logit = final_scores[reranker.true_token].item()
        false_logit = final_scores[reranker.false_token].item()
    return _relevance_probability(true_logit, false_logit)


def _format_verdict(score: float) -> tuple[str, str]:
    """Map P(relevant) to the (label, confidence) strings shown in the UI."""
    label = "Relevant" if score > 0.5 else "Not Relevant"
    return label, f"{score:.2%}"


def _stream_generation(reranker, inputs):
    """Generate on a worker thread, yielding partial reasoning, then the verdict."""
    streamer = TextIteratorStreamer(
        reranker.tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    generation_output = None
    generation_error = None

    def generate_with_output() -> None:
        nonlocal generation_output, generation_error
        try:
            generation_output = reranker.model.generate(
                **inputs,
                generation_config=reranker.generation_config,
                stopping_criteria=reranker.stopping_criteria,
                return_dict_in_generate=True,
                output_scores=True,
                streamer=streamer,
            )
        except Exception as exc:  # re-raised on the consumer side below
            generation_error = exc
            # generate() ends the stream only when it finishes normally; without
            # this, the `for ... in streamer` loop below would block forever.
            streamer.end()

    thread = Thread(target=generate_with_output)
    thread.start()

    # Stream tokens as they're generated.
    current_text = "<think>"
    for new_text in streamer:
        current_text += new_text
        yield ("Processing...", "Processing...", current_text)
    thread.join()
    if generation_error is not None:
        raise generation_error

    # Append the stop sequence recorded by the first stopping criterion, then score.
    current_text += "\n" + reranker.stopping_criteria[0].matched_sequence
    label, confidence = _format_verdict(_score_from_generation(reranker, generation_output))
    yield (label, confidence, current_text)


def _generate_once(reranker, inputs) -> tuple[str, str, str]:
    """Run generation to completion and return (label, confidence, reasoning)."""
    with torch.no_grad():
        outputs = reranker.model.generate(
            **inputs,
            generation_config=reranker.generation_config,
            stopping_criteria=reranker.stopping_criteria,
            return_dict_in_generate=True,
            output_scores=True,
        )
    score = _score_from_generation(reranker, outputs)
    # Decode only the newly generated tokens, not the echoed prompt.
    prompt_length = len(inputs.input_ids[0])
    generated_text = reranker.tokenizer.decode(outputs.sequences[0][prompt_length:])
    verdict = "true" if score > 0.5 else "false"
    reasoning = "<think>\n" + generated_text.strip() + f"\n</think> {verdict}"
    label, confidence = _format_verdict(score)
    return label, confidence, reasoning


@spaces.GPU
def process_input(query: str, passage: str, stream: bool = True) -> tuple[str, str, str]:
    """Judge whether *passage* is relevant to *query* with Rank1, yielding UI updates.

    This is a generator (despite the return annotation). Every yielded value is
    a ``(relevance, confidence, reasoning)`` triple for the three output boxes.
    Both modes finish with the same format: relevance is "Relevant" or
    "Not Relevant", and confidence is P(relevant) as a percentage string.

    Args:
        query: The search query.
        passage: The passage to judge.
        stream: If True, yield ``("Processing...", "Processing...", text)`` as
            reasoning tokens arrive, then one final triple. If False, generate
            to completion and yield a single final triple.

    Raises:
        Any exception raised by ``generate()``. In streaming mode it is
        re-raised here instead of hanging the stream.
    """
    # NOTE(review): a fresh Rank1 is built on every call, which presumably
    # reloads the checkpoint each request; consider hoisting it if
    # @spaces.GPU's execution model allows.
    reranker = Rank1(model_name_or_path="orionweller/rank1-32b-awq")
    prompt = (
        "Determine if the following passage is relevant to the query. "
        "Answer only with 'true' or 'false'.\n"
        f"Query: {query}\n"
        f"Passage: {passage}\n"
        "<think>"
    )
    reranker.model = reranker.model.to("cuda")
    # NOTE(review): if the tokenizer truncates on the right (the transformers
    # default), an over-long passage also drops the trailing "<think>" cue.
    inputs = reranker.tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=reranker.context_size,
    ).to("cuda")

    if stream:
        yield from _stream_generation(reranker, inputs)
    else:
        yield _generate_once(reranker, inputs)
# Example (query, passage) pairs; each row fills the Query and Passage
# textboxes, in that order.
_EXAMPLE_PAIRS = (
    (
        "What movies were directed by James Cameron?",
        "Avatar: The Way of Water is a 2022 American epic science fiction film directed by James Cameron.",
    ),
    (
        "What are the symptoms of COVID-19?",
        "Common symptoms of COVID-19 include fever, cough, fatigue, loss of taste or smell, and difficulty breathing.",
    ),
)
examples = [list(pair) for pair in _EXAMPLE_PAIRS]
# Visual theme for the demo: Gradio's Soft preset with indigo primary color,
# slate neutrals, "lg" corner radius, and an Inter-first font stack.
_THEME_FONTS = ["Inter", "ui-sans-serif", "system-ui", "sans-serif"]
theme = gr.themes.Soft(
    primary_hue="indigo",
    neutral_hue="slate",
    radius_size="lg",
    font=_THEME_FONTS,
)
# Build the Gradio UI: inputs on the left, model outputs on the right, examples
# below, and a button wired to process_input.
# NOTE(review): the `.red-text` CSS class defined here is not used by any
# component in this block; the note below colors text with inline styles.
with gr.Blocks(theme=theme, css=".red-text { color: red; }") as demo:
    gr.Markdown("# Rank1: Test Time Compute in Reranking")
    gr.HTML('NOTE: for demo purposes this is a <span style="color: red;">quantized</span> model with a <span style="color: red;">1024</span> context length. HF spaces cannot use vLLM so this is <span style="color: red;">significantly slower</span>')
    with gr.Row():
        # Left column: user inputs and the submit button.
        with gr.Column():
            query_input = gr.Textbox(
                label="Query",
                placeholder="Enter your search query here",
                lines=2
            )
            passage_input = gr.Textbox(
                label="Passage",
                placeholder="Enter the passage to check for relevance",
                lines=6
            )
            submit_button = gr.Button("Check Relevance")
        # Right column: one box per element of the triple yielded by process_input.
        with gr.Column():
            relevance_output = gr.Textbox(label="Relevance")
            confidence_output = gr.Textbox(label="Confidence")
            reasoning_output = gr.Textbox(
                label="Model Reasoning",
                lines=10,
                interactive=False
            )
    # Clickable examples run the non-streaming path (stream=False); with
    # cache_examples=True, Gradio precomputes and caches their outputs.
    gr.Examples(
        examples=examples,
        inputs=[query_input, passage_input],
        outputs=[relevance_output, confidence_output, reasoning_output],
        fn=partial(process_input, stream=False),
        cache_examples=True,
    )
    # Button handler: process_input is a generator (streaming by default), so each
    # yield updates the three output boxes; exposed under the API name "predict".
    submit_button.click(
        fn=process_input,
        inputs=[query_input, passage_input],
        outputs=[relevance_output, confidence_output, reasoning_output],
        api_name="predict",
        queue=True
    )
def _main() -> None:
    """Launch the Gradio demo, requesting a public share link."""
    demo.launch(share=True)


if __name__ == "__main__":
    _main()