Commit 1559fe0 · Parent: 728a621
Update transformers to 4.39.3 and optimize model loading
Files changed:
- app.py (+9 -7)
- requirements.txt (+1 -1)
app.py
CHANGED
@@ -36,6 +36,7 @@ model_name_or_path = "xiaoxishui/internlm2_5-7b-chat"
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 print(f"Using device: {device}")
 
+
 @dataclass
 class GenerationConfig:
     # this config is used for chat to provide more diversity
@@ -187,7 +188,10 @@ def on_btn_click():
 def load_model():
     model = (AutoModelForCausalLM.from_pretrained(
         model_name_or_path,
-        trust_remote_code=True
+        trust_remote_code=True,
+        use_cache=False,  # disable the KV cache
+        torch_dtype=torch.bfloat16,
+        device_map="auto")).cuda()
     tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,
                                               trust_remote_code=True)
     return model, tokenizer
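
For reference, a minimal standalone sketch of the loader as it stands after this commit (the imports and model_name_or_path come from the top of app.py). It is shown here without the trailing .cuda() from the diff, since device_map="auto" already places the weights via accelerate, and moving an already-dispatched model can raise an error. Note that use_cache=False disables the KV cache, trading slower generation for lower memory use.

# Sketch only: the loader after this commit, minus the redundant .cuda().
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name_or_path = "xiaoxishui/internlm2_5-7b-chat"


def load_model():
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        trust_remote_code=True,
        use_cache=False,              # disable the KV cache
        torch_dtype=torch.bfloat16,
        device_map="auto")            # accelerate handles device placement
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,
                                              trust_remote_code=True)
    return model, tokenizer
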
@@ -210,17 +214,16 @@ def prepare_generation_config():
     return generation_config
 
 
-user_prompt = '
-robot_prompt = '
-cur_query_prompt = '
-    <|im_start|>assistant\n'
+user_prompt = '👥\n{user}\n'
+robot_prompt = '🤖\n{robot}\n'
+cur_query_prompt = '👥\n{user}\n'
 
 
 def combine_history(prompt):
     messages = st.session_state.messages
     meta_instruction = ('You are a helpful, honest, '
                        'and harmless AI assistant.')
-    total_prompt = f'<s
+    total_prompt = f'<s>🤖\n{meta_instruction}\n'
     for message in messages:
         cur_content = message['content']
         if message['role'] == 'user':
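
The hunk above cuts off inside combine_history. A plausible completion, given the templates just added, is sketched below; the 'robot' branch and the final cur_query_prompt append are assumptions inferred from the template names, not lines shown in this commit.

import streamlit as st

user_prompt = '👥\n{user}\n'
robot_prompt = '🤖\n{robot}\n'
cur_query_prompt = '👥\n{user}\n'


def combine_history(prompt):
    messages = st.session_state.messages
    meta_instruction = ('You are a helpful, honest, '
                        'and harmless AI assistant.')
    total_prompt = f'<s>🤖\n{meta_instruction}\n'
    for message in messages:
        cur_content = message['content']
        if message['role'] == 'user':
            total_prompt += user_prompt.format(user=cur_content)
        elif message['role'] == 'robot':
            # assumption: assistant turns use the robot template
            total_prompt += robot_prompt.format(robot=cur_content)
    # assumption: the new query is appended with cur_query_prompt
    total_prompt += cur_query_prompt.format(user=prompt)
    return total_prompt
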
@@ -293,4 +296,3 @@ def main():
 
 if __name__ == '__main__':
     main()
-
requirements.txt
CHANGED
@@ -1,5 +1,5 @@
 streamlit>=1.8.0
-transformers==4.
+transformers==4.39.3
 torch>=2.0.0
 accelerate>=0.20.0
 sentencepiece
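
A quick sanity check (not part of the repo) that the environment actually resolved the pinned version:

import transformers

# requirements.txt pins transformers==4.39.3
assert transformers.__version__ == "4.39.3", transformers.__version__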