from __future__ import annotations

import logging
import math
import torch
from transformers import (
    AsyncTextIteratorStreamer,
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    StoppingCriteria,
    StoppingCriteriaList,
    TextStreamer,
)
from threading import Thread


logger = logging.getLogger(__name__)


class ThinkStoppingCriteria(StoppingCriteria):
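    """Stopping criteria that halts generation once the model has emitted
    '</think> true' or '</think> false', recording which sequence matched."""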
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        # Token id sequences to match; [1:] drops the BOS token the tokenizer prepends
        self.true_sequence = tokenizer("</think> true").input_ids[1:]
        self.false_sequence = tokenizer("</think> false").input_ids[1:]
        self.matched_sequence = None
        
    def __call__(self, input_ids, scores, **kwargs):
        for sequence in [self.true_sequence, self.false_sequence]:
            if input_ids.shape[1] >= len(sequence):
                if all((input_ids[0, -(len(sequence)-i)] == sequence[i] for i in range(len(sequence)))):
                    self.matched_sequence = "</think> true" if sequence is self.true_sequence else "</think> false"
                    return True
        return False


class Rank1:
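    """Reranker that judges query-passage relevance by letting the model reason
    inside a <think> block and then emit 'true' or 'false'. The relevance score
    is derived from the logits of the ' true' and ' false' tokens at the final
    generated position."""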
    def __init__(
        self,
        model_name_or_path: str = "",
        # set these just for demo, typically longer
        context_size: int = 4000,
        max_output_tokens: int = 1024,
        **kwargs,
    ):
        self.context_size = context_size
        self.max_output_tokens = max_output_tokens

        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.tokenizer.padding_side = "left"
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # Cache commonly used token IDs
        self.true_token = self.tokenizer(" true", add_special_tokens=False).input_ids[0]
        self.false_token = self.tokenizer(" false", add_special_tokens=False).input_ids[0]

        # Load the model on CPU initially; predict() moves it to the GPU for
        # inference and back afterwards to free GPU memory between calls.
        # Half precision is required for the flash_attention_2 implementation.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            torch_dtype=torch.float16,
            trust_remote_code=True,
            attn_implementation="flash_attention_2",
        )

        self.stopping_criteria = StoppingCriteriaList([
            ThinkStoppingCriteria(self.tokenizer)
        ])
        
        # Generation config; stopping on "</think> true/false" is handled by the
        # StoppingCriteriaList above, so no stop strings are set here.
        self.generation_config = GenerationConfig(
            max_new_tokens=max_output_tokens,
            do_sample=False,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )

        # Text streamer for simple stdout streaming
        self.streamer = TextStreamer(self.tokenizer)

    async def predict(self, query: str, passage: str, streamer=None):
        """Yield relevance predictions for a (query, passage) pair.

        Each yielded dict contains "is_relevant", "confidence_score", and
        "model_reasoning". In streaming mode, intermediate updates carry None
        for the two scores until generation finishes.
        """
        prompt = f"Determine if the following passage is relevant to the query. Answer only with 'true' or 'false'.\n" \
                f"Query: {query}\n" \
                f"Passage: {passage}\n" \
                "<think>"

        self.model = self.model.to("cuda")
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.context_size
        ).to("cuda")

        if streamer:
            # Create a new streamer for each prediction
            actual_streamer = AsyncTextIteratorStreamer(
                self.tokenizer,
                skip_prompt=True,
                skip_special_tokens=True
            )
            
            current_text = "<think>"
            
            # Run generation in a separate thread and store the output
            generation_output = None
            def generate_with_output():
                nonlocal generation_output
                generation_output = self.model.generate(
                    **inputs,
                    generation_config=self.generation_config,
                    stopping_criteria=self.stopping_criteria,
                    return_dict_in_generate=True,
                    output_scores=True,
                    streamer=actual_streamer
                )
            
            thread = Thread(target=generate_with_output)
            thread.start()
            
            # Stream tokens as they're generated
            async for new_text in actual_streamer:
                current_text += new_text
                yield {
                    "is_relevant": None,
                    "confidence_score": None,
                    "model_reasoning": current_text
                }
            
            thread.join()
            
            # Append the stopping sequence that was matched, if any
            matched_sequence = self.stopping_criteria[0].matched_sequence
            if matched_sequence is not None:
                current_text += "\n" + matched_sequence

            # Calculate the final score from the logits at the last generated position
            with torch.no_grad():
                final_scores = generation_output.scores[-1][0]  # logits for the last position
                true_logit = final_scores[self.true_token].item()
                false_logit = final_scores[self.false_token].item()
                # Two-way softmax, shifted by the max logit to avoid overflow
                max_logit = max(true_logit, false_logit)
                true_score = math.exp(true_logit - max_logit)
                false_score = math.exp(false_logit - max_logit)
                score = true_score / (true_score + false_score)
            
            yield {
                "is_relevant": score > 0.5,
                "confidence_score": score,
                "model_reasoning": current_text
            }
        else:
            # Non-streaming mode
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    generation_config=self.generation_config,
                    stopping_criteria=self.stopping_criteria,
                    return_dict_in_generate=True,
                    output_scores=True
                )
                
                # Calculate the final score from the logits at the last generated position
                final_scores = outputs.scores[-1][0]  # logits for the last position
                true_logit = final_scores[self.true_token].item()
                false_logit = final_scores[self.false_token].item()
                # Two-way softmax, shifted by the max logit to avoid overflow
                max_logit = max(true_logit, false_logit)
                true_score = math.exp(true_logit - max_logit)
                false_score = math.exp(false_logit - max_logit)
                score = true_score / (true_score + false_score)

                # Decode only the newly generated tokens (the prompt is skipped)
                new_tokens = outputs.sequences[0][len(inputs.input_ids[0]):]
                decoded_output = self.tokenizer.decode(new_tokens)
                output_reasoning = (
                    "<think>\n" + decoded_output.strip()
                    + f"\n</think> {'true' if score > 0.5 else 'false'}"
                )
                
                yield {
                    "is_relevant": score > 0.5,
                    "confidence_score": score,
                    "model_reasoning": output_reasoning
                }

        # Move model back to CPU
        self.model = self.model.to("cpu")
        torch.cuda.empty_cache()
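

# Minimal usage sketch, not part of the class above. Assumptions: the checkpoint
# path below is a placeholder for whatever rank1-style model you have locally,
# and a CUDA device is available. predict() is an async generator, so it is
# consumed with `async for`; in non-streaming mode it yields a single final dict.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        reranker = Rank1(model_name_or_path="path/to/rank1-checkpoint")  # placeholder path
        result = None
        async for update in reranker.predict(
            query="What causes the seasons on Earth?",
            passage="Seasons result from the tilt of Earth's rotational axis relative to its orbital plane.",
        ):
            result = update  # keep the most recent yielded dict
        print(result["is_relevant"], result["confidence_score"])

    asyncio.run(_demo())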