oweller2 committed
Commit 00134aa · 1 Parent(s): ff7a0f2
Files changed (4)
  1. README.md +5 -4
  2. app.py +105 -0
  3. model.py +185 -0
  4. requirements.txt +8 -0
README.md CHANGED
@@ -1,12 +1,13 @@
  ---
- title: Rank1 Demo
- emoji: 🐠
- colorFrom: red
- colorTo: pink
+ title: "Rank1: Test Time Compute in Reranking"
+ emoji: 🏆
+ colorFrom: blue
+ colorTo: green
  sdk: gradio
  sdk_version: 5.17.1
  app_file: app.py
  pinned: false
+ license: mit
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,105 @@
+ import sys
+ import warnings
+ from functools import partial
+
+ import gradio as gr
+ import spaces
+ import torch
+ import numpy as np
+ from model import Rank1
+
+ print(f"NumPy version: {np.__version__}")
+ print(f"PyTorch version: {torch.__version__}")
+
+ # Suppress CUDA initialization warning
+ warnings.filterwarnings("ignore", category=UserWarning, message="Can't initialize NVML")
+
+ try:
+     reranker = Rank1(model_name_or_path="orionweller/rank1-32b-awq")
+ except Exception as e:
+     print(f"Error loading model: {e}")
+     sys.exit(1)
+
+ @spaces.GPU
+ async def process_input(query: str, passage: str, stream: bool = True):
+     """Run the reranker on a query/passage pair and yield formatted outputs."""
+     try:
+         async for result in reranker.predict(query, passage, streamer=stream):
+             if result["is_relevant"] is None:
+                 # Intermediate streaming result
+                 yield "Processing...", "Processing...", result["model_reasoning"]
+             else:
+                 # Final result
+                 relevance = "Relevant" if result["is_relevant"] else "Not Relevant"
+                 confidence = f"{result['confidence_score']:.2%}"
+                 reasoning = result["model_reasoning"]
+                 yield relevance, confidence, reasoning
+     except Exception as e:
+         yield f"Error: {str(e)}", "N/A", "An error occurred during processing"
+
+ # Example inputs
+ examples = [
+     [
+         "What movies were directed by James Cameron?",
+         "Avatar: The Way of Water is a 2022 American epic science fiction film directed by James Cameron.",
+     ],
+     [
+         "What are the symptoms of COVID-19?",
+         "Common symptoms of COVID-19 include fever, cough, fatigue, loss of taste or smell, and difficulty breathing.",
+     ],
+ ]
+
+ theme = gr.themes.Soft(
+     primary_hue="indigo",
+     font=["Inter", "ui-sans-serif", "system-ui", "sans-serif"],
+     neutral_hue="slate",
+     radius_size="lg",
+ )
+
+ with gr.Blocks(theme=theme, css=".red-text { color: red; }") as demo:
+     gr.Markdown("# Rank1: Test Time Compute in Reranking")
+     gr.HTML('NOTE: for demo purposes this is a <span style="color: red;">quantized</span> model with a <span style="color: red;">1024</span> context length. HF Spaces cannot use vLLM, so this is <span style="color: red;">significantly slower</span>.')
+
+     with gr.Row():
+         with gr.Column():
+             query_input = gr.Textbox(
+                 label="Query",
+                 placeholder="Enter your search query here",
+                 lines=2
+             )
+             passage_input = gr.Textbox(
+                 label="Passage",
+                 placeholder="Enter the passage to check for relevance",
+                 lines=6
+             )
+             submit_button = gr.Button("Check Relevance")
+
+         with gr.Column():
+             relevance_output = gr.Textbox(label="Relevance")
+             confidence_output = gr.Textbox(label="Confidence")
+             reasoning_output = gr.Textbox(
+                 label="Model Reasoning",
+                 lines=10,
+                 interactive=False
+             )
+
+     gr.Examples(
+         examples=examples,
+         inputs=[query_input, passage_input],
+         outputs=[relevance_output, confidence_output, reasoning_output],
+         fn=partial(process_input, stream=False),
+         cache_examples=True,
+     )
+
+     submit_button.click(
+         fn=process_input,
+         inputs=[query_input, passage_input],
+         outputs=[relevance_output, confidence_output, reasoning_output],
+         api_name="predict",
+         queue=True
+     )
+
+ if __name__ == "__main__":
+     demo.launch(share=True)
model.py ADDED
@@ -0,0 +1,185 @@
+ from __future__ import annotations
+
+ import logging
+ import math
+ from threading import Thread
+
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, AsyncTextIteratorStreamer
+ from transformers import StoppingCriteria, StoppingCriteriaList
+
+
+ logger = logging.getLogger(__name__)
+
+
+ class ThinkStoppingCriteria(StoppingCriteria):
+     """Stop generation once the model has emitted "</think> true" or "</think> false"."""
+
+     def __init__(self, tokenizer):
+         self.tokenizer = tokenizer
+         self.true_sequence = tokenizer("</think> true").input_ids[1:]  # Skip first token
+         self.false_sequence = tokenizer("</think> false").input_ids[1:]  # Skip first token
+         self.matched_sequence = None
+
+     def __call__(self, input_ids, scores, **kwargs):
+         for sequence in [self.true_sequence, self.false_sequence]:
+             if input_ids.shape[1] >= len(sequence):
+                 if all(input_ids[0, -(len(sequence) - i)] == sequence[i] for i in range(len(sequence))):
+                     self.matched_sequence = "</think> true" if sequence is self.true_sequence else "</think> false"
+                     return True
+         return False
+
+
+ class Rank1:
+     def __init__(
+         self,
+         model_name_or_path: str = "",
+         # set these just for the demo; typically longer
+         context_size: int = 4000,
+         max_output_tokens: int = 1024,
+         **kwargs,
+     ):
+         self.context_size = context_size
+         self.max_output_tokens = max_output_tokens
+
+         # Initialize tokenizer
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
+         self.tokenizer.padding_side = "left"
+         self.tokenizer.pad_token = self.tokenizer.eos_token
+
+         # Cache commonly used token IDs
+         self.true_token = self.tokenizer(" true", add_special_tokens=False).input_ids[0]
+         self.false_token = self.tokenizer(" false", add_special_tokens=False).input_ids[0]
+
+         # Load the AWQ model on CPU initially; it is moved to GPU for each prediction
+         self.model = AutoModelForCausalLM.from_pretrained(
+             model_name_or_path,
+             device_map="cpu",
+             trust_remote_code=True,
+             attn_implementation="flash_attention_2"
+         )
+
+         self.stopping_criteria = StoppingCriteriaList([
+             ThinkStoppingCriteria(self.tokenizer)
+         ])
+
+         # Generation config; stopping on "</think> true/false" is handled by the criteria above
+         self.generation_config = GenerationConfig(
+             max_new_tokens=max_output_tokens,
+             do_sample=False,
+             pad_token_id=self.tokenizer.pad_token_id,
+             eos_token_id=self.tokenizer.eos_token_id
+         )
+
+     async def predict(self, query: str, passage: str, streamer=None):
+         """Predict the relevance of the passage to the query, yielding intermediate and final results."""
+         prompt = f"Determine if the following passage is relevant to the query. Answer only with 'true' or 'false'.\n" \
+                  f"Query: {query}\n" \
+                  f"Passage: {passage}\n" \
+                  "<think>"
+
+         # Clear any match recorded by a previous call
+         self.stopping_criteria[0].matched_sequence = None
+
+         self.model = self.model.to("cuda")
+         inputs = self.tokenizer(
+             prompt,
+             return_tensors="pt",
+             truncation=True,
+             max_length=self.context_size
+         ).to("cuda")
+
+         if streamer:
+             # Create a new streamer for each prediction
+             actual_streamer = AsyncTextIteratorStreamer(
+                 self.tokenizer,
+                 skip_prompt=True,
+                 skip_special_tokens=True
+             )
+
+             current_text = "<think>"
+
+             # Run generation in a separate thread and store the output
+             generation_output = None
+
+             def generate_with_output():
+                 nonlocal generation_output
+                 generation_output = self.model.generate(
+                     **inputs,
+                     generation_config=self.generation_config,
+                     stopping_criteria=self.stopping_criteria,
+                     return_dict_in_generate=True,
+                     output_scores=True,
+                     streamer=actual_streamer
+                 )
+
+             thread = Thread(target=generate_with_output)
+             thread.start()
+
+             # Stream tokens as they're generated
+             async for new_text in actual_streamer:
+                 current_text += new_text
+                 yield {
+                     "is_relevant": None,
+                     "confidence_score": None,
+                     "model_reasoning": current_text
+                 }
+
+             thread.join()
+
+             # Add the stopping sequence that was matched, if any
+             if self.stopping_criteria[0].matched_sequence:
+                 current_text += "\n" + self.stopping_criteria[0].matched_sequence
+
+             # Compute the relevance score from the true/false logits at the last generated position
+             with torch.no_grad():
+                 final_scores = generation_output.scores[-1][0]
+                 true_logit = final_scores[self.true_token].item()
+                 false_logit = final_scores[self.false_token].item()
+                 true_score = math.exp(true_logit)
+                 false_score = math.exp(false_logit)
+                 score = true_score / (true_score + false_score)
+
+             yield {
+                 "is_relevant": score > 0.5,
+                 "confidence_score": score,
+                 "model_reasoning": current_text
+             }
+         else:
+             # Non-streaming mode
+             with torch.no_grad():
+                 outputs = self.model.generate(
+                     **inputs,
+                     generation_config=self.generation_config,
+                     stopping_criteria=self.stopping_criteria,
+                     return_dict_in_generate=True,
+                     output_scores=True
+                 )
+
+             # Compute the relevance score from the true/false logits at the last generated position
+             final_scores = outputs.scores[-1][0]
+             true_logit = final_scores[self.true_token].item()
+             false_logit = final_scores[self.false_token].item()
+             true_score = math.exp(true_logit)
+             false_score = math.exp(false_logit)
+             score = true_score / (true_score + false_score)
+
+             # Only decode the newly generated tokens; drop the echoed stop sequence, which is re-added below
+             new_tokens = outputs.sequences[0][len(inputs.input_ids[0]):]
+             decoded = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
+             for stop in ("</think> true", "</think> false"):
+                 decoded = decoded.replace(stop, "")
+             output_reasoning = "<think>\n" + decoded.strip() + f"\n</think> {'true' if score > 0.5 else 'false'}"
+
+             yield {
+                 "is_relevant": score > 0.5,
+                 "confidence_score": score,
+                 "model_reasoning": output_reasoning
+             }
+
+         # Move the model back to CPU and free GPU memory
+         self.model = self.model.to("cpu")
+         torch.cuda.empty_cache()
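A minimal usage sketch for the Rank1 class defined above, outside of the Gradio app (assumes a CUDA device and the same checkpoint; predict is an async generator, so it is consumed with asyncio):

import asyncio
from model import Rank1

reranker = Rank1(model_name_or_path="orionweller/rank1-32b-awq")

async def main():
    query = "What movies were directed by James Cameron?"
    passage = "Avatar: The Way of Water is a 2022 film directed by James Cameron."
    # Non-streaming mode: the generator yields a single final result dict
    async for result in reranker.predict(query, passage, streamer=False):
        print(result["is_relevant"], result["confidence_score"])
        print(result["model_reasoning"])

asyncio.run(main())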
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ gradio==5.17.1
+ spaces
+ transformers==4.49.0
+ numpy==1.24.3
+ flash_attn==2.6.3
+ autoawq==0.2.1
+ autoawq_kernels==0.0.9
+ torch==2.5.1+cu121
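Because app.py registers the click handler with api_name="predict", the demo can also be queried remotely. A minimal sketch using gradio_client; the Space id below is a placeholder, not something confirmed by this commit:

from gradio_client import Client

# Placeholder Space id; substitute the actual owner/space name once deployed.
client = Client("orionweller/rank1-demo")
relevance, confidence, reasoning = client.predict(
    "What are the symptoms of COVID-19?",
    "Common symptoms of COVID-19 include fever, cough, and fatigue.",
    api_name="/predict",
)
print(relevance, confidence)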