ffreemt committed
Commit 0b6d9b3 · 1 Parent(s): eecb4fb

Update app.py

Files changed (1):
  1. app.py +22 -18
app.py CHANGED
@@ -1,14 +1,14 @@
-from loguru import logger
-import rich
+# pylint: disable=invalid-name, line-too-long, missing-module-docstring
+import gc
 import os
 import time
-import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
-import gc
 
-from huggingface_hub import hf_hub_download
+import rich
+import torch
 from huggingface_hub import snapshot_download
-# snapshot_download(repo_id="lysandre/arxiv-nlp")
+from loguru import logger
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.generation.utils import GenerationConfig
 
 model_name = "baichuan-inc/Baichuan2-13B-Chat-4bits"
 # snapshot_download?
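
The import block now follows the stdlib-first, third-party-second grouping, hoists the imports that previously sat mid-file (GenerationConfig and the duplicate transformers imports), and drops the unused hf_hub_download. The header of the next hunk shows the download the script does up front; roughly, as a sketch of that call (repo id and local_dir taken straight from the code):

    # sketch of the snapshot_download call visible in the next hunk header
    from huggingface_hub import snapshot_download

    loc = snapshot_download(
        repo_id="baichuan-inc/Baichuan2-13B-Chat-4bits",
        local_dir="model",  # from_pretrained("model") below loads from this folder
    )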
@@ -18,7 +18,7 @@ loc = snapshot_download(repo_id=model_name, local_dir="model")
 os.environ["TZ"] = "Asia/Shanghai"
 try:
     time.tzset()  # type: ignore # pylint: disable=no-member
-except Exception:
+except Exception:  # pylint: disable=broad-except
     # Windows
     logger.warning("Windows, cant run time.tzset()")
 
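
time.tzset() exists only on Unix, so on Windows the attribute lookup itself fails; the commit keeps the broad handler and just silences pylint. A tighter variant would catch the exact error Windows raises, e.g. (a sketch, not the committed code):

    try:
        time.tzset()  # type: ignore[attr-defined]  # Unix-only
    except AttributeError:  # Windows: the time module has no tzset
        logger.warning("time.tzset() unavailable; leaving timezone as-is")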
@@ -30,16 +30,19 @@ has_cuda = torch.cuda.is_available()
 
 if has_cuda:
     model = AutoModelForCausalLM.from_pretrained(
-        "model",  # loc
-        # device_map="auto",
-        torch_dtype=torch.bfloat16,
+        "model",  # loc
+        # device_map="auto",
+        torch_dtype=torch.bfloat16,  # pylint: disable=no-member
         load_in_8bit=True,
         trust_remote_code=True,
         # use_ram_optimized_load=False,
         # offload_folder="offload_folder",
     ).cuda()
 else:
-    model = AutoModel.from_pretrained(model_name, trust_remote_code=True).float()
+    # model = AutoModel.from_pretrained(model_name, trust_remote_code=True).float()
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name, trust_remote_code=True
+    ).float()
 
 model = model.eval()
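
Two fixes in this hunk. First, the old CPU branch called AutoModel, which is never imported (only AutoModelForCausalLM is), so it would have died with a NameError; the new branch uses the right class and keeps .float() for CPU inference. Second, pairing load_in_8bit=True with an explicit .cuda() is fragile: depending on the transformers/bitsandbytes versions, moving an already-quantized model is rejected, and the commented-out device_map="auto" points at the usual alternative (a sketch under that assumption, not the committed code):

    # sketch: let accelerate place the 8-bit weights instead of calling .cuda()
    model = AutoModelForCausalLM.from_pretrained(
        "model",
        device_map="auto",      # needs the accelerate package
        load_in_8bit=True,      # needs the bitsandbytes package
        trust_remote_code=True,
    )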
@@ -47,18 +50,19 @@ rich.print(f"{model=}")
 
 logger.info("done")
 
-# ========
-from transformers import AutoModelForCausalLM, AutoTokenizer
-from transformers.generation.utils import GenerationConfig
-tokenizer = AutoTokenizer.from_pretrained("baichuan-inc/Baichuan2-13B-Chat-4bits", use_fast=False, trust_remote_code=True)
+tokenizer = AutoTokenizer.from_pretrained(
+    "baichuan-inc/Baichuan2-13B-Chat-4bits", use_fast=False, trust_remote_code=True
+)
 
 # model = AutoModelForCausalLM.from_pretrained("baichuan-inc/Baichuan2-13B-Chat-4bits", device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True)
 
-model.generation_config = GenerationConfig.from_pretrained("baichuan-inc/Baichuan2-13B-Chat-4bits")
+model.generation_config = GenerationConfig.from_pretrained(
+    "baichuan-inc/Baichuan2-13B-Chat-4bits"
+)
 messages = []
 messages.append({"role": "user", "content": "解释一下“温故而知新”"})
 response = model.chat(tokenizer, messages)
 
 rich.print(response)
 
-logger.info(f"{response=}")
+logger.info(f"{response=}")
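
With the imports hoisted to the top, this block shrinks to the tokenizer, the generation config, and a one-question smoke test: the prompt asks the model, in Chinese, to explain the idiom 温故而知新 ("review the old to learn the new"). A second turn would presumably look like this, assuming Baichuan2's remote-code chat() treats messages as running history (a sketch, not part of the commit):

    messages.append({"role": "assistant", "content": response})
    messages.append({"role": "user", "content": "请举一个例子。"})  # "Give an example."
    followup = model.chat(tokenizer, messages)
    rich.print(followup)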
 