# NOTE(review): removed three lines of non-Python scrape residue that were
# accidentally prepended to this file (a "File size" header, a run of git
# commit hashes, and a column-number dump). They broke the module at parse time.
# pylint: disable=invalid-name, line-too-long, missing-module-docstring
import gc
import os
import time
import gradio
import rich
import torch
from huggingface_hub import snapshot_download
from loguru import logger
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig
# Repo id of the 4-bit quantized Baichuan2 chat checkpoint.
model_name = "baichuan-inc/Baichuan2-13B-Chat-4bits"

# The 4-bit quantized checkpoint needs a CUDA device; bail out early otherwise.
if not torch.cuda.is_available():
    # NOTE(review): gradio.Error is constructed but never raised here, so it is
    # a no-op at module-import time (Gradio only surfaces it when raised inside
    # an event handler) — confirm whether it should be raised instead.
    gradio.Error(f"No cuda, cant run {model_name}")
    raise SystemError(f"No cuda, cant run {model_name}")

# Download the checkpoint once; `loc` is the local directory path ("model").
loc = snapshot_download(repo_id=model_name, local_dir="model")

# Fix timezone on Linux so log timestamps use Asia/Shanghai.
os.environ["TZ"] = "Asia/Shanghai"
try:
    time.tzset()  # type: ignore # pylint: disable=no-member
except AttributeError:
    # time.tzset() does not exist on Windows — narrowed from the original
    # broad `except Exception`, which could also have hidden real bugs.
    logger.warning("Windows, cant run time.tzset()")

model = None
gc.collect()  # for interactive testing: release any previously loaded model
logger.info("start")
has_cuda = torch.cuda.is_available()
has_cuda = torch.cuda.is_available()
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, trust_remote_code=True)
if has_cuda:
model = AutoModelForCausalLM.from_pretrained(
"model", # loc
device_map="auto",
torch_dtype=torch.bfloat16, # pylint: disable=no-member
load_in_8bit=True,
trust_remote_code=True,
# use_ram_optimized_load=False,
# offload_folder="offload_folder",
) # .cuda()
else:
try:
# model = AutoModel.from_pretrained(model_name, trust_remote_code=True).float()
model = AutoModelForCausalLM.from_pretrained(
# model_name, trust_remote_code=True
"model",
trust_remote_code=True,
) # .float() not supported
except Exception as exc:
logger.error(exc)
logger.warning("Doesnt seem to load for CPU...")
raise SystemExit(1) from exc
model = model.eval()
rich.print(f"{model=}")
logger.info("done")
# Consistency fix: reuse `model_name` instead of repeating the hard-coded repo
# id, and drop the second AutoTokenizer.from_pretrained call — it duplicated
# the earlier load byte-for-byte (same repo id, use_fast=False,
# trust_remote_code=True) and only re-downloaded/re-parsed the same tokenizer.
model.generation_config = GenerationConfig.from_pretrained(model_name)

# Smoke-test the chat interface with a single user turn.
messages = []
messages.append({"role": "user", "content": "解释一下“温故而知新”"})
response = model.chat(tokenizer, messages)
rich.print(response)
logger.info(f"{response=}")
# NOTE(review): removed a stray trailing "|" left over from scrape extraction.