from __future__ import annotations

import logging

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    StoppingCriteria,
    StoppingCriteriaList,
    TextStreamer,
)

logger = logging.getLogger(__name__)


class ThinkStoppingCriteria(StoppingCriteria):
    """Stops generation as soon as the model emits " true" or " false"."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        # Skip the leading special token (e.g. BOS) the tokenizer prepends.
        self.true_sequence = tokenizer(" true").input_ids[1:]
        self.false_sequence = tokenizer(" false").input_ids[1:]
        self.matched_sequence = None

    def __call__(self, input_ids, scores, **kwargs):
        for sequence in (self.true_sequence, self.false_sequence):
            if input_ids.shape[1] >= len(sequence):
                # Compare the tail of the generated ids against the target sequence.
                if all(
                    input_ids[0, -(len(sequence) - i)] == sequence[i]
                    for i in range(len(sequence))
                ):
                    self.matched_sequence = (
                        " true" if sequence is self.true_sequence else " false"
                    )
                    return True
        return False


class Rank1:
    def __init__(
        self,
        model_name_or_path: str = "",
        # Set small just for the demo; typically these are longer.
        context_size: int = 4000,
        max_output_tokens: int = 1024,
        **kwargs,
    ):
        self.context_size = context_size
        self.max_output_tokens = max_output_tokens

        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.tokenizer.padding_side = "left"
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # Cache commonly used token IDs
        self.true_token = self.tokenizer(" true", add_special_tokens=False).input_ids[0]
        self.false_token = self.tokenizer(" false", add_special_tokens=False).input_ids[0]

        # Load the (AWQ-quantized) model; device_map="auto" lets accelerate
        # place the weights across the available devices.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            device_map="auto",
            trust_remote_code=True,
            attn_implementation="flash_attention_2",
        )

        # Stop generating once the model has committed to " true" or " false"
        self.stopping_criteria = StoppingCriteriaList([
            ThinkStoppingCriteria(self.tokenizer)
        ])

        # Greedy generation config
        self.generation_config = GenerationConfig(
            max_new_tokens=max_output_tokens,
            do_sample=False,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )

        # Text streamer for printing tokens as they are generated
        self.streamer = TextStreamer(self.tokenizer)
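
# Minimal usage sketch. This is an illustration only: the model path and the
# prompt template below are placeholders (assumptions), not the exact format
# the model was trained with. It shows how the cached " true"/" false" token
# IDs and the stopping criteria can be combined into a relevance probability
# from the logits of the final generation step.
if __name__ == "__main__":
    reranker = Rank1(model_name_or_path="path/to/rank1-awq-model")  # placeholder path

    # Hypothetical relevance prompt; the real template may differ.
    prompt = (
        "Determine whether the passage is relevant to the query. "
        "Answer true or false.\n"
        "Query: what causes rain\n"
        "Passage: Rain forms when water vapor condenses into droplets.\n"
    )
    inputs = reranker.tokenizer(prompt, return_tensors="pt").to(reranker.model.device)

    with torch.no_grad():
        output = reranker.model.generate(
            **inputs,
            generation_config=reranker.generation_config,
            stopping_criteria=reranker.stopping_criteria,
            output_scores=True,
            return_dict_in_generate=True,
        )

    # Logits of the step that produced the final (" true" / " false") token.
    final_scores = output.scores[-1][0]
    true_false = torch.stack(
        [final_scores[reranker.true_token], final_scores[reranker.false_token]]
    )
    prob_true = torch.softmax(true_false, dim=0)[0].item()
    print(f"Relevance probability: {prob_true:.3f}")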