"""Run qwen 7b.

transformers 4.31.0
"""
import os
import time

import torch
from loguru import logger
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from transformers.generation import GenerationConfig

os.environ["TZ"] = "Asia/Shanghai"
try:
    time.tzset()  # type: ignore # pylint: disable=no-member
except Exception:
    # time.tzset() does not exist on Windows
    logger.warning("Windows: can't run time.tzset()")

device_map = "cuda:0" if torch.cuda.is_available() else "cpu"
# device_map = "cpu"  # uncomment to force CPU

model_name = "Qwen/Qwen-7B-Chat"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Quantization: pick exactly one config. (The original assigned both in
# sequence, so only the second, Int8, ever took effect.)

# NF4 (4-bit) quantization; uncomment to use instead of Int8
# quantization_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.bfloat16,
# )

# Int8 (8-bit) quantization
quantization_config = BitsAndBytesConfig(load_in_8bit=True)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=device_map,
    quantization_config=quantization_config,
    # max_memory=max_memory,
    trust_remote_code=True,
).eval()
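# Optional sanity check: get_memory_footprint() (available on transformers
# PreTrainedModel) returns the parameter/buffer size in bytes, which makes it
# easy to confirm the 8-bit load actually shrank the model.
logger.info("model footprint: {:.2f} GB", model.get_memory_footprint() / 1e9)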

# Alternative loads:
# Let transformers shard/place layers automatically:
# model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True).eval()
# Unquantized bf16 weights (runs, but needs more VRAM):
# model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True, bf16=True).eval()

# Generation hyperparameters (max length, top_p, etc.) can be overridden here
model.generation_config = GenerationConfig.from_pretrained(model_name, trust_remote_code=True)

response, history = model.chat(tokenizer, "你好", history=[])  # "你好": "Hello"
print(response)
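# Follow-up turn: model.chat() returns the running history as a list of
# (query, response) pairs; pass it back in to keep conversational context.
# The prompt below is just an illustration.
response, history = model.chat(tokenizer, "Tell me a short joke.", history=history)
print(response)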