from __future__ import annotations

import logging
import math
from threading import Thread

import torch
from transformers import (
    AsyncTextIteratorStreamer,
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    StoppingCriteria,
    StoppingCriteriaList,
    TextStreamer,
)

logger = logging.getLogger(__name__)
class ThinkStoppingCriteria(StoppingCriteria):
    """Stops generation once the model emits "</think> true" or "</think> false"."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        # Skip the leading special token added by the tokenizer
        self.true_sequence = tokenizer("</think> true").input_ids[1:]
        self.false_sequence = tokenizer("</think> false").input_ids[1:]
        self.matched_sequence = None

    def __call__(self, input_ids, scores, **kwargs):
        for sequence in (self.true_sequence, self.false_sequence):
            if input_ids.shape[1] >= len(sequence):
                # Compare the trailing tokens against the stop sequence
                if input_ids[0, -len(sequence):].tolist() == sequence:
                    self.matched_sequence = (
                        "</think> true" if sequence is self.true_sequence else "</think> false"
                    )
                    return True
        return False
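
# Example (illustrative): with this criteria in a StoppingCriteriaList, generation
# halts as soon as the trailing tokens match "</think> true" or "</think> false",
# and matched_sequence records which of the two was produced; if generation instead
# ends via EOS or max_new_tokens, matched_sequence stays None.
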
class Rank1:
    def __init__(
        self,
        model_name_or_path: str = "",
        # Kept small for the demo; production values are typically larger
        context_size: int = 4000,
        max_output_tokens: int = 1024,
        **kwargs,
    ):
        self.context_size = context_size
        self.max_output_tokens = max_output_tokens

        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.tokenizer.padding_side = "left"
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # Cache the token ids used to score relevance
        self.true_token = self.tokenizer(" true", add_special_tokens=False).input_ids[0]
        self.false_token = self.tokenizer(" false", add_special_tokens=False).input_ids[0]

        # Load the model; it is moved to the GPU for each prediction and back to CPU afterwards
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            device_map="auto",
            trust_remote_code=True,
            attn_implementation="flash_attention_2",
        )

        self.stopping_criteria = StoppingCriteriaList([
            ThinkStoppingCriteria(self.tokenizer)
        ])
        # Default console streamer (streaming calls create their own async streamer)
        self.streamer = TextStreamer(self.tokenizer)

        # Generation config; stopping on "</think> true"/"</think> false" is
        # handled by the stopping criteria above
        self.generation_config = GenerationConfig(
            max_new_tokens=max_output_tokens,
            do_sample=False,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )
    async def predict(self, query: str, passage: str, streamer=None):
        """Predict relevance of passage to query, optionally streaming the reasoning."""
        prompt = (
            "Determine if the following passage is relevant to the query. "
            "Answer only with 'true' or 'false'.\n"
            f"Query: {query}\n"
            f"Passage: {passage}\n"
            "<think>"
        )

        self.model = self.model.to("cuda")
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.context_size,
        ).to("cuda")

        if streamer:
            # Create a fresh async streamer for each prediction
            actual_streamer = AsyncTextIteratorStreamer(
                self.tokenizer,
                skip_prompt=True,
                skip_special_tokens=True,
            )
            current_text = "<think>"

            # Run generation in a background thread and capture its output
            generation_output = None

            def generate_with_output():
                nonlocal generation_output
                generation_output = self.model.generate(
                    **inputs,
                    generation_config=self.generation_config,
                    stopping_criteria=self.stopping_criteria,
                    return_dict_in_generate=True,
                    output_scores=True,
                    streamer=actual_streamer,
                )

            thread = Thread(target=generate_with_output)
            thread.start()

            # Stream partial reasoning as tokens are generated
            async for new_text in actual_streamer:
                current_text += new_text
                yield {
                    "is_relevant": None,
                    "confidence_score": None,
                    "model_reasoning": current_text,
                }
            thread.join()
            # Append the matched stop sequence if the streamer did not already emit it
            matched = self.stopping_criteria[0].matched_sequence
            if matched and not current_text.rstrip().endswith(matched):
                current_text += "\n" + matched

            # Compute the relevance score from the logits at the last generated position
            with torch.no_grad():
                final_scores = generation_output.scores[-1][0]
                true_logit = final_scores[self.true_token].item()
                false_logit = final_scores[self.false_token].item()
                true_score = math.exp(true_logit)
                false_score = math.exp(false_logit)
                score = true_score / (true_score + false_score)

            yield {
                "is_relevant": score > 0.5,
                "confidence_score": score,
                "model_reasoning": current_text,
            }
        else:
            # Non-streaming mode
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    generation_config=self.generation_config,
                    stopping_criteria=self.stopping_criteria,
                    return_dict_in_generate=True,
                    output_scores=True,
                )

            # Compute the relevance score from the logits at the last generated position
            final_scores = outputs.scores[-1][0]
            true_logit = final_scores[self.true_token].item()
            false_logit = final_scores[self.false_token].item()
            true_score = math.exp(true_logit)
            false_score = math.exp(false_logit)
            score = true_score / (true_score + false_score)

            # Decode only the newly generated tokens
            new_tokens = outputs.sequences[0][len(inputs.input_ids[0]):]
            decoded_output = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

            # Drop an emitted stop sequence so it is not duplicated when rebuilding the reasoning
            for stop in ("</think> true", "</think> false"):
                if decoded_output.endswith(stop):
                    decoded_output = decoded_output[: -len(stop)].rstrip()

            output_reasoning = (
                "<think>\n" + decoded_output
                + f"\n</think> {'true' if score > 0.5 else 'false'}"
            )

            yield {
                "is_relevant": score > 0.5,
                "confidence_score": score,
                "model_reasoning": output_reasoning,
            }
        # Move the model back to CPU and free GPU memory between predictions
        self.model = self.model.to("cpu")
        torch.cuda.empty_cache()
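
# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of driving Rank1.predict as an async generator. The checkpoint id,
# query, and passage below are placeholders chosen for demonstration; substitute the
# model and inputs you actually use.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        ranker = Rank1(model_name_or_path="your-org/your-rank1-checkpoint")  # placeholder id
        query = "What causes ocean tides?"
        passage = "Tides are driven mainly by the gravitational pull of the Moon and the Sun."

        # streamer=True routes through the streaming branch: intermediate dicts carry
        # partial reasoning, and the final dict adds is_relevant and confidence_score.
        last = None
        async for result in ranker.predict(query, passage, streamer=True):
            last = result
        print(last["is_relevant"], last["confidence_score"])

    asyncio.run(_demo())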