from __future__ import annotations

import logging
import math
from threading import Thread

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
    TextStreamer,
)

logger = logging.getLogger(__name__)
class ThinkStoppingCriteria(StoppingCriteria):
    """Stop generation once the model emits "</think> true" or "</think> false"."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        # Pre-tokenize both stop phrases; [1:] drops the BOS token the
        # tokenizer prepends, so we match only the phrase itself
        self.true_sequence = tokenizer("</think> true").input_ids[1:]
        self.false_sequence = tokenizer("</think> false").input_ids[1:]
        self.matched_sequence = None

    def __call__(self, input_ids, scores, **kwargs):
        # Compare the tail of the generated ids against each stop phrase
        for sequence in [self.true_sequence, self.false_sequence]:
            if input_ids.shape[1] >= len(sequence):
                if input_ids[0, -len(sequence):].tolist() == list(sequence):
                    self.matched_sequence = (
                        "</think> true" if sequence is self.true_sequence else "</think> false"
                    )
                    return True
        return False
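
# Illustrative usage (an assumption, not from the original snippet): the
# criteria is evaluated after every generated token, and `matched_sequence`
# records which phrase ended generation, e.g.
#
#   criteria = ThinkStoppingCriteria(tokenizer)
#   model.generate(**inputs, stopping_criteria=StoppingCriteriaList([criteria]))
#   if criteria.matched_sequence == "</think> true":
#       ...  # the model's final answer was "true"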
class Rank1:
    def __init__(
        self,
        model_name_or_path: str = "",
        # small values for the demo; production settings are typically larger
        context_size: int = 4000,
        max_output_tokens: int = 1024,
        **kwargs,
    ):
        self.context_size = context_size
        self.max_output_tokens = max_output_tokens

        # Initialize tokenizer; left padding so generation starts immediately
        # after the prompt in batched inputs
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.tokenizer.padding_side = "left"
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # Cache commonly used token IDs
        self.true_token = self.tokenizer(" true", add_special_tokens=False).input_ids[0]
        self.false_token = self.tokenizer(" false", add_special_tokens=False).input_ids[0]
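        # Note (assumption): taking [0] relies on " true" and " false" each
        # tokenizing to a single ID, which holds for common BPE tokenizers;
        # verify with len(self.tokenizer(" true", add_special_tokens=False).input_ids) == 1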

        # Load the model (an AWQ-quantized checkpoint); device_map="auto"
        # places layers on whatever devices are available.
        # flash_attention_2 requires half precision, hence the explicit dtype.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            device_map="auto",
            torch_dtype=torch.float16,
            trust_remote_code=True,
            attn_implementation="flash_attention_2",
        )

        self.stopping_criteria = StoppingCriteriaList([
            ThinkStoppingCriteria(self.tokenizer)
        ])

        # Greedy decoding; early stopping on "</think> true/false" is handled
        # by the stopping criteria above (GenerationConfig itself has no
        # `stopping_sequences` field, so the duplicate config that passed one
        # is dropped)
        self.generation_config = GenerationConfig(
            max_new_tokens=max_output_tokens,
            do_sample=False,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )

        # Stream decoded text to stdout as tokens are generated
        self.streamer = TextStreamer(self.tokenizer)
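
    # Sketch (an assumption, not part of the original snippet): how the cached
    # true/false token IDs and the stopping criteria typically combine into a
    # relevance score. The method name, prompt handling, and the two-way
    # softmax are illustrative.
    @torch.inference_mode()
    def score(self, prompt: str) -> float:
        """Generate until "</think> true/false" and return P(true) over {true, false}."""
        inputs = self.tokenizer(
            prompt, return_tensors="pt", truncation=True, max_length=self.context_size
        ).to(self.model.device)
        outputs = self.model.generate(
            **inputs,
            generation_config=self.generation_config,
            stopping_criteria=self.stopping_criteria,
            streamer=self.streamer,
            return_dict_in_generate=True,
            output_scores=True,
        )
        # The last generated token is " true" or " false"; its logits are the
        # final entry of `scores`
        final_logits = outputs.scores[-1][0]
        true_logit = final_logits[self.true_token].item()
        false_logit = final_logits[self.false_token].item()
        # Softmax restricted to the two candidate tokens
        return math.exp(true_logit) / (math.exp(true_logit) + math.exp(false_logit))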