oweller2 committed on
Commit 8aa9a18 · 1 Parent(s): 2413d91

move away from async

Files changed (1)
  1. model.py +9 -12
model.py CHANGED

@@ -3,7 +3,7 @@ from __future__ import annotations
 import logging
 import math
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, TextStreamer, AsyncTextIteratorStreamer
+from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, TextStreamer, AsyncTextIteratorStreamer, TextIteratorStreamer
 from transformers import StoppingCriteria, StoppingCriteriaList
 from transformers import AwqConfig, AutoModelForCausalLM
 from threading import Thread
@@ -81,7 +81,7 @@ class Rank1:
             stopping_sequences=["</think> true", "</think> false"]
         )
 
-    async def predict(self, query: str, passage: str, streamer=None):
+    def predict(self, query: str, passage: str, stream: bool = False):
         """Predict relevance of passage to query."""
         prompt = f"Determine if the following passage is relevant to the query. Answer only with 'true' or 'false'.\n" \
                  f"Query: {query}\n" \
@@ -96,18 +96,16 @@ class Rank1:
             max_length=self.context_size
         ).to("cuda")
 
-        if streamer:
-            # Create a new streamer for each prediction
-            actual_streamer = AsyncTextIteratorStreamer(
+        if stream:
+            streamer = TextIteratorStreamer(
                 self.tokenizer,
                 skip_prompt=True,
                 skip_special_tokens=True
             )
 
             current_text = "<think>"
-
-            # Run generation in a separate thread and store the output
             generation_output = None
+
             def generate_with_output():
                 nonlocal generation_output
                 generation_output = self.model.generate(
@@ -116,14 +114,14 @@
                     stopping_criteria=self.stopping_criteria,
                     return_dict_in_generate=True,
                     output_scores=True,
-                    streamer=actual_streamer
+                    streamer=streamer
                 )
 
             thread = Thread(target=generate_with_output)
             thread.start()
 
             # Stream tokens as they're generated
-            async for new_text in actual_streamer:
+            for new_text in streamer:
                 current_text += new_text
                 yield {
                     "is_relevant": None,
@@ -133,12 +131,11 @@
 
             thread.join()
 
-            # Add the stopping sequence that was matched
+            # Add the stopping sequence and calculate final scores
             current_text += "\n" + self.stopping_criteria[0].matched_sequence
 
-            # Calculate final scores using the last scores from generation
             with torch.no_grad():
-                final_scores = generation_output.scores[-1][0]  # Get logits from last position
+                final_scores = generation_output.scores[-1][0]
                 true_logit = final_scores[self.true_token].item()
                 false_logit = final_scores[self.false_token].item()
                 true_score = math.exp(true_logit)
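
With this change, predict is an ordinary synchronous generator when stream=True: model.generate runs on a background Thread and pushes decoded text into the TextIteratorStreamer, which the caller drains with a plain for loop instead of async for. A minimal usage sketch follows, assuming an already-constructed Rank1 instance; the constructor arguments, the sample query/passage, and any chunk keys besides "is_relevant" are assumptions not shown in this diff, and the final yield is assumed to carry the computed relevance result.

# Usage sketch (hypothetical): consuming the now-synchronous streaming generator.
from model import Rank1  # this repo's model.py

ranker = Rank1()  # constructor arguments omitted; not shown in this diff

query = "Is caffeine a stimulant?"
passage = "Caffeine is a central nervous system stimulant found in coffee and tea."

# A plain for loop replaces the old `async for`; no event loop is needed.
for chunk in ranker.predict(query, passage, stream=True):
    if chunk["is_relevant"] is None:
        # Intermediate chunk: generation is still streaming the <think> text.
        continue
    # Assumed final chunk, emitted after thread.join() and score computation.
    print(chunk)

This follows the standard synchronous streaming pattern in transformers: TextIteratorStreamer blocks on an internal queue in the consuming thread while generate() fills it from the worker thread, so no asyncio machinery is required anywhere in the call chain.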