from __future__ import annotations

import logging
import math
from threading import Thread

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

logger = logging.getLogger(__name__)


class ThinkStoppingCriteria(StoppingCriteria):
    """Stops generation as soon as the model emits " true" or " false"."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        # Skip the leading special (BOS) token added by the tokenizer
        self.true_sequence = tokenizer(" true").input_ids[1:]
        self.false_sequence = tokenizer(" false").input_ids[1:]
        self.matched_sequence = None

    def __call__(self, input_ids, scores, **kwargs):
        for sequence in [self.true_sequence, self.false_sequence]:
            if input_ids.shape[1] >= len(sequence):
                # Check whether the last len(sequence) generated tokens match
                if all(
                    input_ids[0, -(len(sequence) - i)] == sequence[i]
                    for i in range(len(sequence))
                ):
                    self.matched_sequence = (
                        " true" if sequence is self.true_sequence else " false"
                    )
                    return True
        return False


class Rank1:
    def __init__(
        self,
        model_name_or_path: str = "",
        # Set low just for the demo; typically much longer
        context_size: int = 4000,
        max_output_tokens: int = 1024,
        **kwargs,
    ):
        self.context_size = context_size
        self.max_output_tokens = max_output_tokens

        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.tokenizer.padding_side = "left"
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # Cache the token IDs used for scoring
        self.true_token = self.tokenizer(" true", add_special_tokens=False).input_ids[0]
        self.false_token = self.tokenizer(" false", add_special_tokens=False).input_ids[0]

        # Load the model; device_map="auto" handles weight placement and
        # fp16 weights are required for flash_attention_2
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            device_map="auto",
            torch_dtype=torch.float16,
            trust_remote_code=True,
            attn_implementation="flash_attention_2",
        )

        # Stop as soon as the model commits to " true" or " false"
        self.stopping_criteria = StoppingCriteriaList([
            ThinkStoppingCriteria(self.tokenizer)
        ])

        # Greedy decoding; stopping is handled by the criteria above
        self.generation_config = GenerationConfig(
            max_new_tokens=max_output_tokens,
            do_sample=False,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )

    def predict(self, query: str, passage: str, stream: bool = False):
        """Predict the relevance of a passage to a query.

        Yields dicts with ``is_relevant``, ``confidence_score`` and
        ``model_reasoning``. In streaming mode, intermediate dicts carry the
        partial reasoning with ``None`` scores; the final dict carries the
        relevance decision and confidence.
        """
        prompt = (
            "Determine if the following passage is relevant to the query.\n"
            "Answer only with 'true' or 'false'.\n"
            f"Query: {query}\n"
            f"Passage: {passage}\n"
        )

        # Ensure the model is on the GPU for generation
        self.model = self.model.to("cuda")
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.context_size,
        ).to("cuda")

        if stream:
            streamer = TextIteratorStreamer(
                self.tokenizer,
                skip_prompt=True,
                skip_special_tokens=True,
            )

            current_text = ""
            generation_output = None

            def generate_with_output():
                # Run generation in a background thread so the streamer can be
                # consumed on the main thread
                nonlocal generation_output
                generation_output = self.model.generate(
                    **inputs,
                    generation_config=self.generation_config,
                    stopping_criteria=self.stopping_criteria,
                    return_dict_in_generate=True,
                    output_scores=True,
                    streamer=streamer,
                )

            thread = Thread(target=generate_with_output)
            thread.start()

            # Stream partial reasoning as tokens are generated
            for new_text in streamer:
                current_text += new_text
                yield {
                    "is_relevant": None,
                    "confidence_score": None,
                    "model_reasoning": current_text,
                }

            thread.join()

            # Append the matched stopping sequence (None if generation ended
            # for another reason, e.g. hitting max_new_tokens)
            matched = self.stopping_criteria[0].matched_sequence
            if matched is not None:
                current_text += "\n" + matched

            with torch.no_grad():
                # Logits at the position where " true"/" false" was decided
                final_scores = generation_output.scores[-1][0]
                true_logit = final_scores[self.true_token].item()
                false_logit = final_scores[self.false_token].item()
                # Two-way softmax over the true/false logits
                true_score = math.exp(true_logit)
                false_score = math.exp(false_logit)
                score = true_score / (true_score + false_score)

            yield {
                "is_relevant": score > 0.5,
                "confidence_score": score,
                "model_reasoning": current_text,
            }
        else:
            # Non-streaming mode
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    generation_config=self.generation_config,
                    stopping_criteria=self.stopping_criteria,
                    return_dict_in_generate=True,
                    output_scores=True,
                )

                # Logits from the last generated position
                final_scores = outputs.scores[-1][0]
                true_logit = final_scores[self.true_token].item()
                false_logit = final_scores[self.false_token].item()
                # Two-way softmax over the true/false logits
                true_score = math.exp(true_logit)
                false_score = math.exp(false_logit)
                score = true_score / (true_score + false_score)

                # Only decode the newly generated tokens, not the prompt
                new_tokens = outputs.sequences[0][len(inputs.input_ids[0]):]
                decoded_output = self.tokenizer.decode(new_tokens)
                output_reasoning = (
                    "\n" + decoded_output.strip()
                    + f"\n {'true' if score > 0.5 else 'false'}"
                )

                yield {
                    "is_relevant": score > 0.5,
                    "confidence_score": score,
                    "model_reasoning": output_reasoning,
                }

        # Move the model back to the CPU and free GPU memory between calls
        self.model = self.model.to("cpu")
        torch.cuda.empty_cache()
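

# Minimal usage sketch for the class above. The checkpoint path, query, and
# passage below are placeholder assumptions; substitute whatever rank1-style
# model you actually serve. predict() is a generator in both modes, so drain
# it fully to let the trailing cleanup move the model back to the CPU.
if __name__ == "__main__":
    reranker = Rank1(model_name_or_path="path/to/rank1-checkpoint")  # placeholder path

    query = "What is the capital of France?"
    passage = "Paris is the capital and most populous city of France."

    # Non-streaming: the generator yields a single result dict
    results = list(reranker.predict(query=query, passage=passage))
    final = results[-1]
    print(final["is_relevant"], round(final["confidence_score"], 3))

    # Streaming: intermediate dicts carry partial reasoning; the last one carries the score
    for update in reranker.predict(query=query, passage=passage, stream=True):
        if update["confidence_score"] is not None:
            print("final:", update["is_relevant"], update["confidence_score"])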