ffreemt committed
Commit · 0b6d9b3
1 Parent(s): eecb4fb
Update app.py

app.py CHANGED
@@ -1,14 +1,14 @@
-
-import
+# pylint: disable=invalid-name, line-too-long, missing-module-docstring
+import gc
 import os
 import time
-import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
-import gc
 
-
+import rich
+import torch
 from huggingface_hub import snapshot_download
-
+from loguru import logger
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.generation.utils import GenerationConfig
 
 model_name = "baichuan-inc/Baichuan2-13B-Chat-4bits"
 # snapshot_download?
@@ -18,7 +18,7 @@ loc = snapshot_download(repo_id=model_name, local_dir="model")
 os.environ["TZ"] = "Asia/Shanghai"
 try:
     time.tzset()  # type: ignore # pylint: disable=no-member
-except Exception:
+except Exception:  # pylint: disable=broad-except
     # Windows
     logger.warning("Windows, cant run time.tzset()")
 
@@ -30,16 +30,19 @@ has_cuda = torch.cuda.is_available()
 
 if has_cuda:
     model = AutoModelForCausalLM.from_pretrained(
-        "model",
-        # device_map="auto",
-        torch_dtype=torch.bfloat16,
+        "model",  # loc
+        # device_map="auto",
+        torch_dtype=torch.bfloat16,  # pylint: disable=no-member
         load_in_8bit=True,
         trust_remote_code=True,
         # use_ram_optimized_load=False,
         # offload_folder="offload_folder",
     ).cuda()
 else:
-    model = AutoModel.from_pretrained(model_name, trust_remote_code=True).float()
+    # model = AutoModel.from_pretrained(model_name, trust_remote_code=True).float()
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name, trust_remote_code=True
+    ).float()
 
 model = model.eval()
 
@@ -47,18 +50,19 @@ rich.print(f"{model=}")
 
 logger.info("done")
 
-
-
-
-tokenizer = AutoTokenizer.from_pretrained("baichuan-inc/Baichuan2-13B-Chat-4bits", use_fast=False, trust_remote_code=True)
+tokenizer = AutoTokenizer.from_pretrained(
+    "baichuan-inc/Baichuan2-13B-Chat-4bits", use_fast=False, trust_remote_code=True
+)
 
 # model = AutoModelForCausalLM.from_pretrained("baichuan-inc/Baichuan2-13B-Chat-4bits", device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True)
 
-model.generation_config = GenerationConfig.from_pretrained(
+model.generation_config = GenerationConfig.from_pretrained(
+    "baichuan-inc/Baichuan2-13B-Chat-4bits"
+)
 messages = []
 messages.append({"role": "user", "content": "解释一下“温故而知新”"})
 response = model.chat(tokenizer, messages)
 
 rich.print(response)
 
-logger.info(f"{response=}")
+logger.info(f"{response=}")
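For context, the updated script boils down to the flow below. This is a condensed sketch, not the verbatim file: device_map="auto" is re-enabled here (the commit leaves it commented out) because 8-bit loading generally wants a device map, and the .cuda() call from the diff is dropped since 8-bit weights are placed on the GPU at load time; the chat() helper comes from Baichuan2's trust_remote_code implementation.

# Condensed sketch of the updated app.py flow (assumptions noted above).
import torch
from huggingface_hub import snapshot_download
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig

model_name = "baichuan-inc/Baichuan2-13B-Chat-4bits"
loc = snapshot_download(repo_id=model_name, local_dir="model")  # local copy in ./model

if torch.cuda.is_available():
    # GPU path: 8-bit weights with bfloat16 compute (requires bitsandbytes).
    model = AutoModelForCausalLM.from_pretrained(
        "model",  # the snapshot_download target
        device_map="auto",  # assumption: re-enabled; commented out in the commit
        torch_dtype=torch.bfloat16,
        load_in_8bit=True,
        trust_remote_code=True,
    )
else:
    # CPU fallback: full float32, loaded straight from the Hub id.
    model = AutoModelForCausalLM.from_pretrained(
        model_name, trust_remote_code=True
    ).float()
model = model.eval()

tokenizer = AutoTokenizer.from_pretrained(
    model_name, use_fast=False, trust_remote_code=True
)
model.generation_config = GenerationConfig.from_pretrained(model_name)

# Baichuan2's remote code exposes chat(), which takes the tokenizer and an
# OpenAI-style message list and returns the assistant reply as a string.
messages = [{"role": "user", "content": "解释一下“温故而知新”"}]  # "explain 'gaining new insight by revisiting the old'"
response = model.chat(tokenizer, messages)
print(response)

Relative to the old file, the commit is mostly repair work: it adds the missing rich, loguru, and GenerationConfig imports, switches the CPU branch from the never-imported AutoModel to AutoModelForCausalLM, and completes the tokenizer and generation-config calls as proper multi-line expressions.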