oliver-aizip committed
Commit fd247b7
1 Parent(s): 7f28f16

add gen prompt and kwargs dicts

Files changed (1)
  1. utils/models.py +23 -4
utils/models.py CHANGED
@@ -29,6 +29,8 @@ models = {
 
 }
 
+tokenizer_cache = {}
+
 # List of model names for easy access
 model_names = list(models.keys())
 
@@ -101,13 +103,29 @@ def run_inference(model_name, context, question):
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     result = ""
-    model_kwargs = {} # make sure qwen3 doesn't use thinking
+    tokenizer_kwargs = {
+        "add_generation_prompt": True,
+    } # make sure qwen3 doesn't use thinking
+    generation_kwargs = {
+        "max_new_tokens": 512,
+    }
     if "qwen3" in model_name.lower():
         print(f"Recognized {model_name} as a Qwen3 model. Setting enable_thinking=False.")
-        model_kwargs["enable_thinking"] = False
+        tokenizer_kwargs["enable_thinking"] = False
+        generation_kwargs["enable_thinking"] = False
 
     try:
-        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left", token=True, kwargs=model_kwargs)
+        if model_name in tokenizer_cache:
+            tokenizer = tokenizer_cache[model_name]
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_name,
+                padding_side="left",
+                token=True,
+                kwargs=tokenizer_kwargs
+            )
+            tokenizer_cache[model_name] = tokenizer
+
         accepts_sys = (
             "System role not supported" not in tokenizer.chat_template
             if tokenizer.chat_template else False # Handle missing chat_template
@@ -126,6 +144,7 @@ def run_inference(model_name, context, question):
             tokenizer=tokenizer,
             device_map='auto',
             trust_remote_code=True,
+            torch_dtype=torch.bfloat16,
         )
 
         text_input = format_rag_prompt(question, context, accepts_sys)
@@ -134,7 +153,7 @@ def run_inference(model_name, context, question):
         if generation_interrupt.is_set():
             return ""
 
-        outputs = pipe(text_input, max_new_tokens=512)
+        outputs = pipe(text_input, max_new_tokens=512, generate_kwargs=generation_kwargs)
         result = outputs[0]['generated_text'][-1]['content']
 
     except Exception as e:
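
As context for the enable_thinking handling in this diff: on Qwen3 checkpoints the thinking mode is normally toggled through tokenizer.apply_chat_template(..., enable_thinking=False), not through AutoTokenizer.from_pretrained or a kwargs= argument. The sketch below is not part of this commit; it is a minimal illustration of that documented pattern, and the checkpoint id and messages are placeholder assumptions.

# Minimal sketch (not part of this commit): disabling Qwen3's thinking mode via
# the chat template. Checkpoint id and messages are placeholder assumptions.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen3-0.6B"  # assumed example checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")

messages = [{"role": "user", "content": "Answer using only the provided context."}]
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False,  # chat-template flag that suppresses the <think> block
)

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=512)
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True))

Whether the committed kwargs=tokenizer_kwargs and generate_kwargs=generation_kwargs arguments actually reach the chat template depends on the installed transformers version; the apply_chat_template route above is the one documented for Qwen3.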