Spaces: Running on Zero
Commit 8aa9a18 (parent: 2413d91)
oweller2 committed: move away from async
model.py
CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
 import logging
 import math
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, TextStreamer, AsyncTextIteratorStreamer
+from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, TextStreamer, AsyncTextIteratorStreamer, TextIteratorStreamer
 from transformers import StoppingCriteria, StoppingCriteriaList
 from transformers import AwqConfig, AutoModelForCausalLM
 from threading import Thread
@@ -81,7 +81,7 @@ class Rank1:
             stopping_sequences=["</think> true", "</think> false"]
         )

-
+    def predict(self, query: str, passage: str, stream: bool = False):
         """Predict relevance of passage to query."""
         prompt = f"Determine if the following passage is relevant to the query. Answer only with 'true' or 'false'.\n" \
                  f"Query: {query}\n" \
@@ -96,18 +96,16 @@ class Rank1:
             max_length=self.context_size
         ).to("cuda")

-        if
-
-            actual_streamer = AsyncTextIteratorStreamer(
+        if stream:
+            streamer = TextIteratorStreamer(
                 self.tokenizer,
                 skip_prompt=True,
                 skip_special_tokens=True
             )

             current_text = "<think>"
-
-            # Run generation in a separate thread and store the output
             generation_output = None
+
             def generate_with_output():
                 nonlocal generation_output
                 generation_output = self.model.generate(
@@ -116,14 +114,14 @@ class Rank1:
                     stopping_criteria=self.stopping_criteria,
                     return_dict_in_generate=True,
                     output_scores=True,
-                    streamer=
+                    streamer=streamer
                 )

             thread = Thread(target=generate_with_output)
             thread.start()

             # Stream tokens as they're generated
-
+            for new_text in streamer:
                 current_text += new_text
                 yield {
                     "is_relevant": None,
@@ -133,12 +131,11 @@ class Rank1:

             thread.join()

-            # Add the stopping sequence
+            # Add the stopping sequence and calculate final scores
             current_text += "\n" + self.stopping_criteria[0].matched_sequence

-            # Calculate final scores using the last scores from generation
             with torch.no_grad():
-                final_scores = generation_output.scores[-1][0]
+                final_scores = generation_output.scores[-1][0]
                 true_logit = final_scores[self.true_token].item()
                 false_logit = final_scores[self.false_token].item()
                 true_score = math.exp(true_logit)