oweller2 committed
Commit 00588f0 · Parent(s): 8aa9a18

working without async

Files changed (2)
  1. app.py +97 -17
  2. model.py +1 -101
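In short: the asyncio path is gone. `app.py` no longer awaits `Rank1.predict` through an `AsyncTextIteratorStreamer`; it drives generation itself as a plain generator, running `model.generate` on a background `Thread` and iterating a synchronous `TextIteratorStreamer`, and the now-unused `predict` method is deleted from `model.py`. Distilled to its core, the pattern the commit adopts looks like this (a minimal sketch with a placeholder model, not the Space's exact code):

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")           # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Query: ...\nPassage: ...\n<think>", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks, so it runs on a worker thread; the main thread
# consumes decoded text chunks from the streamer as they arrive.
thread = Thread(target=model.generate,
                kwargs=dict(**inputs, streamer=streamer, max_new_tokens=64))
thread.start()
for chunk in streamer:
    print(chunk, end="", flush=True)
thread.join()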
app.py CHANGED
@@ -1,15 +1,15 @@
 import sys
 import warnings
 import spaces
-import asyncio
 from threading import Thread
-from transformers import AsyncTextIteratorStreamer
+from transformers import TextIteratorStreamer
 from functools import partial
 
 import gradio as gr
 import torch
 import numpy as np
 from model import Rank1
+import math
 
 print(f"NumPy version: {np.__version__}")
 print(f"PyTorch version: {torch.__version__}")
@@ -18,22 +18,102 @@ print(f"PyTorch version: {torch.__version__}")
 warnings.filterwarnings("ignore", category=UserWarning, message="Can't initialize NVML")
 
 @spaces.GPU
-async def process_input(query: str, passage: str, stream: bool = True) -> tuple[str, str, str]:
+def process_input(query: str, passage: str, stream: bool = True) -> tuple[str, str, str]:
     """Process input through the reranker and return formatted outputs."""
-    try:
-        reranker = Rank1(model_name_or_path="orionweller/rank1-32b-awq")
-        async for result in reranker.predict(query, passage, streamer=stream):
-            if result["is_relevant"] is None:
-                # Intermediate streaming result
-                yield "Processing...", "Processing...", result["model_reasoning"]
-            else:
-                # Final result
-                relevance = "Relevant" if result["is_relevant"] else "Not Relevant"
-                confidence = f"{result['confidence_score']:.2%}"
-                reasoning = result["model_reasoning"]
-                yield relevance, confidence, reasoning
-    except Exception as e:
-        yield f"Error: {str(e)}", "N/A", "An error occurred during processing"
+    reranker = Rank1(model_name_or_path="orionweller/rank1-32b-awq")
+    prompt = f"Determine if the following passage is relevant to the query. Answer only with 'true' or 'false'.\n" \
+             f"Query: {query}\n" \
+             f"Passage: {passage}\n" \
+             "<think>"
+
+    reranker.model = reranker.model.to("cuda")
+    inputs = reranker.tokenizer(
+        prompt,
+        return_tensors="pt",
+        truncation=True,
+        max_length=reranker.context_size
+    ).to("cuda")
+
+    if stream:
+        streamer = TextIteratorStreamer(
+            reranker.tokenizer,
+            skip_prompt=True,
+            skip_special_tokens=True
+        )
+
+        current_text = "<think>"
+        generation_output = None
+
+        def generate_with_output():
+            nonlocal generation_output
+            generation_output = reranker.model.generate(
+                **inputs,
+                generation_config=reranker.generation_config,
+                stopping_criteria=reranker.stopping_criteria,
+                return_dict_in_generate=True,
+                output_scores=True,
+                streamer=streamer
+            )
+
+        thread = Thread(target=generate_with_output)
+        thread.start()
+
+        # Stream tokens as they're generated
+        for new_text in streamer:
+            current_text += new_text
+            yield (
+                "Processing...",
+                "Processing...",
+                current_text
+            )
+
+        thread.join()
+
+        # Add the stopping sequence and calculate final scores
+        current_text += "\n" + reranker.stopping_criteria[0].matched_sequence
+
+        with torch.no_grad():
+            final_scores = generation_output.scores[-1][0]
+            true_logit = final_scores[reranker.true_token].item()
+            false_logit = final_scores[reranker.false_token].item()
+            true_score = math.exp(true_logit)
+            false_score = math.exp(false_logit)
+            score = true_score / (true_score + false_score)
+
+        yield (
+            score > 0.5,
+            score,
+            current_text
+        )
+    else:
+        # Non-streaming mode
+        with torch.no_grad():
+            outputs = reranker.model.generate(
+                **inputs,
+                generation_config=reranker.generation_config,
+                stopping_criteria=reranker.stopping_criteria,
+                return_dict_in_generate=True,
+                output_scores=True
+            )
+
+        # Get final score from generation outputs
+        final_scores = outputs.scores[-1][0]  # logits from the last position
+        true_logit = final_scores[reranker.true_token].item()
+        false_logit = final_scores[reranker.false_token].item()
+        true_score = math.exp(true_logit)
+        false_score = math.exp(false_logit)
+        score = true_score / (true_score + false_score)
+
+        # Only decode the generated text
+        new_text = outputs.sequences[0][len(inputs.input_ids[0]):]
+        decoded_input = reranker.tokenizer.decode(new_text)
+        output_reasoning = "<think>\n" + decoded_input.strip() + f"\n</think> {'true' if score > 0.5 else 'false'}"
+
+        yield (
+            "Relevant" if score > 0.5 else "Not Relevant",
+            f"{score:.2%}",
+            output_reasoning
+        )
 
 # Example inputs
 examples = [
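For reference, the relevance score in both branches of `process_input` is a two-way softmax over the logits of the `true` and `false` tokens at the last generation step. Isolated as a helper (a sketch of the arithmetic only; it subtracts the max logit before exponentiating, a numerically safer variant of the raw `math.exp` used above):

import math

def relevance_score(true_logit: float, false_logit: float) -> float:
    """P(true) under a two-way softmax over the true/false logits."""
    # Shifting by the max logit avoids overflow in exp(); the result
    # is identical to exp(t)/(exp(t)+exp(f)) whenever that fits in a float.
    m = max(true_logit, false_logit)
    true_score = math.exp(true_logit - m)
    false_score = math.exp(false_logit - m)
    return true_score / (true_score + false_score)

For example, `relevance_score(2.0, -1.0)` is about 0.9526, which the non-streaming branch would format as "95.26%" and report as "Relevant".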
model.py CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
 import logging
 import math
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, TextStreamer, AsyncTextIteratorStreamer, TextIteratorStreamer
+from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, TextStreamer, TextIteratorStreamer
 from transformers import StoppingCriteria, StoppingCriteriaList
 from transformers import AwqConfig, AutoModelForCausalLM
 from threading import Thread
@@ -80,103 +80,3 @@ class Rank1:
             eos_token_id=self.tokenizer.eos_token_id,
             stopping_sequences=["</think> true", "</think> false"]
         )
-
-    def predict(self, query: str, passage: str, stream: bool = False):
-        """Predict relevance of passage to query."""
-        prompt = f"Determine if the following passage is relevant to the query. Answer only with 'true' or 'false'.\n" \
-                 f"Query: {query}\n" \
-                 f"Passage: {passage}\n" \
-                 "<think>"
-
-        self.model = self.model.to("cuda")
-        inputs = self.tokenizer(
-            prompt,
-            return_tensors="pt",
-            truncation=True,
-            max_length=self.context_size
-        ).to("cuda")
-
-        if stream:
-            streamer = TextIteratorStreamer(
-                self.tokenizer,
-                skip_prompt=True,
-                skip_special_tokens=True
-            )
-
-            current_text = "<think>"
-            generation_output = None
-
-            def generate_with_output():
-                nonlocal generation_output
-                generation_output = self.model.generate(
-                    **inputs,
-                    generation_config=self.generation_config,
-                    stopping_criteria=self.stopping_criteria,
-                    return_dict_in_generate=True,
-                    output_scores=True,
-                    streamer=streamer
-                )
-
-            thread = Thread(target=generate_with_output)
-            thread.start()
-
-            # Stream tokens as they're generated
-            for new_text in streamer:
-                current_text += new_text
-                yield {
-                    "is_relevant": None,
-                    "confidence_score": None,
-                    "model_reasoning": current_text
-                }
-
-            thread.join()
-
-            # Add the stopping sequence and calculate final scores
-            current_text += "\n" + self.stopping_criteria[0].matched_sequence
-
-            with torch.no_grad():
-                final_scores = generation_output.scores[-1][0]
-                true_logit = final_scores[self.true_token].item()
-                false_logit = final_scores[self.false_token].item()
-                true_score = math.exp(true_logit)
-                false_score = math.exp(false_logit)
-                score = true_score / (true_score + false_score)
-
-            yield {
-                "is_relevant": score > 0.5,
-                "confidence_score": score,
-                "model_reasoning": current_text
-            }
-        else:
-            # Non-streaming mode
-            with torch.no_grad():
-                outputs = self.model.generate(
-                    **inputs,
-                    generation_config=self.generation_config,
-                    stopping_criteria=self.stopping_criteria,
-                    return_dict_in_generate=True,
-                    output_scores=True
-                )
-
-            # Get final score from generation outputs
-            final_scores = outputs.scores[-1][0]  # logits from the last position
-            true_logit = final_scores[self.true_token].item()
-            false_logit = final_scores[self.false_token].item()
-            true_score = math.exp(true_logit)
-            false_score = math.exp(false_logit)
-            score = true_score / (true_score + false_score)
-
-            # Only decode the generated text
-            new_text = outputs.sequences[0][len(inputs.input_ids[0]):]
-            decoded_input = self.tokenizer.decode(new_text)
-            output_reasoning = "<think>\n" + decoded_input.strip() + f"\n</think> {'true' if score > 0.5 else 'false'}"
-
-            yield {
-                "is_relevant": score > 0.5,
-                "confidence_score": score,
-                "model_reasoning": output_reasoning
-            }
-
-        # Move model back to CPU
-        self.model = self.model.to("cpu")
-        torch.cuda.empty_cache()
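Because `process_input` is now an ordinary generator yielding `(relevance, confidence, reasoning)` tuples, Gradio can stream it without any async plumbing; a hypothetical wiring (the Space's actual interface code, including the truncated `examples` list, is not part of this diff):

import gradio as gr

# Hypothetical interface; component choices and labels are assumptions.
demo = gr.Interface(
    fn=process_input,  # a generator fn: Gradio re-renders the outputs on every yield
    inputs=[gr.Textbox(label="Query"), gr.Textbox(label="Passage")],
    outputs=[
        gr.Textbox(label="Relevance"),
        gr.Textbox(label="Confidence"),
        gr.Textbox(label="Model reasoning"),
    ],
)
demo.launch()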