Spaces: Running on Zero

Commit · fd247b7
Parent(s): 7f28f16

add gen prompt and kwargs dicts

utils/models.py CHANGED (+23 -4)
@@ -29,6 +29,8 @@ models = {
 
 }
 
+tokenizer_cache = {}
+
 # List of model names for easy access
 model_names = list(models.keys())
 
@@ -101,13 +103,29 @@ def run_inference(model_name, context, question):
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     result = ""
-
+    tokenizer_kwargs = {
+        "add_generation_prompt": True,
+    }  # make sure qwen3 doesn't use thinking
+    generation_kwargs = {
+        "max_new_tokens": 512,
+    }
     if "qwen3" in model_name.lower():
         print(f"Recognized {model_name} as a Qwen3 model. Setting enable_thinking=False.")
-
+        tokenizer_kwargs["enable_thinking"] = False
+        generation_kwargs["enable_thinking"] = False
 
     try:
-
+        if model_name in tokenizer_cache:
+            tokenizer = tokenizer_cache[model_name]
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_name,
+                padding_side="left",
+                token=True,
+                kwargs=tokenizer_kwargs
+            )
+            tokenizer_cache[model_name] = tokenizer
+
         accepts_sys = (
             "System role not supported" not in tokenizer.chat_template
             if tokenizer.chat_template else False  # Handle missing chat_template
@@ -126,6 +144,7 @@ def run_inference(model_name, context, question):
             tokenizer=tokenizer,
             device_map='auto',
             trust_remote_code=True,
+            torch_dtype=torch.bfloat16,
         )
 
         text_input = format_rag_prompt(question, context, accepts_sys)
@@ -134,7 +153,7 @@ def run_inference(model_name, context, question):
         if generation_interrupt.is_set():
             return ""
 
-        outputs = pipe(text_input, max_new_tokens=512)
+        outputs = pipe(text_input, max_new_tokens=512, generate_kwargs=generation_kwargs)
         result = outputs[0]['generated_text'][-1]['content']
 
     except Exception as e:
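The module-level tokenizer_cache introduced above avoids re-instantiating a tokenizer on every call to run_inference. A minimal standalone sketch of that caching pattern, using only the documented AutoTokenizer.from_pretrained API (the get_tokenizer helper name is illustrative and not part of this commit):

from transformers import AutoTokenizer

tokenizer_cache = {}

def get_tokenizer(model_name: str):
    """Return a cached tokenizer, loading it only on the first request."""
    if model_name not in tokenizer_cache:
        tokenizer_cache[model_name] = AutoTokenizer.from_pretrained(
            model_name,
            padding_side="left",  # left-pad for decoder-only generation
        )
    return tokenizer_cache[model_name]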
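For context on how the two kwargs dicts are typically consumed, here is a rough, self-contained sketch assuming the documented transformers and Qwen3 chat-template APIs: add_generation_prompt and enable_thinking are arguments of apply_chat_template, while the generation kwargs go to the pipeline call. The model id and prompt below are placeholders, and this is an illustration of the pattern rather than the Space's exact code.

import torch
from transformers import AutoTokenizer, pipeline

model_name = "Qwen/Qwen3-0.6B"  # placeholder model id

tokenizer_kwargs = {"add_generation_prompt": True}
generation_kwargs = {"max_new_tokens": 512}
if "qwen3" in model_name.lower():
    # Qwen3's chat template accepts enable_thinking; False suppresses the <think> block.
    tokenizer_kwargs["enable_thinking"] = False

tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
messages = [{"role": "user", "content": "Answer using only the provided context."}]

# add_generation_prompt / enable_thinking are consumed when the chat template is applied.
prompt = tokenizer.apply_chat_template(messages, tokenize=False, **tokenizer_kwargs)

pipe = pipeline(
    "text-generation",
    model=model_name,
    tokenizer=tokenizer,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
outputs = pipe(prompt, **generation_kwargs)
print(outputs[0]["generated_text"])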