# qwen-7b-chat / app.py
"""Run qwen 7b.
transformers 4.31.0
"""
import os
import time

import torch
from loguru import logger
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from transformers.generation import GenerationConfig

os.environ["TZ"] = "Asia/Shanghai"
try:
    time.tzset()  # type: ignore # pylint: disable=no-member
except Exception:
    # time.tzset() is not available on Windows
    logger.warning("Windows, can't run time.tzset()")

device_map = "cuda:0" if torch.cuda.is_available() else "cpu"
# has_cuda = False # force cpu
model_name = "Qwen/Qwen-7B-Chat"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Quantization configuration for NF4 (4 bits).
# Note: this config is overwritten by the Int8 one below; keep only the
# assignment you want to take effect.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Quantization configuration for Int8 (8 bits); this is the one actually used.
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
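# A sketch (not in the original script): one could pick NF4 automatically when
# GPU memory is tight; the 12 GiB threshold is an illustrative guess.
# if torch.cuda.is_available() and (
#     torch.cuda.get_device_properties(0).total_memory / 1024**3 < 12
# ):
#     quantization_config = BitsAndBytesConfig(
#         load_in_4bit=True,
#         bnb_4bit_quant_type="nf4",
#         bnb_4bit_compute_dtype=torch.bfloat16,
#     )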

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=device_map,
    quantization_config=quantization_config,
    # max_memory=max_memory,
    trust_remote_code=True,
).eval()
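# Added note (not in the original): log how much memory the quantized model
# occupies; get_memory_footprint() is a standard transformers model method.
logger.info(f"model memory footprint: {model.get_memory_footprint() / 1024**3:.2f} GiB")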
# model = model.eval()  # redundant; .eval() is already chained above
# Alternatives (commented out); the bf16 variant below also runs:
# model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True).eval()
# model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True, bf16=True).eval()

# Different generation lengths, top_p, and other hyperparameters can be specified.
model.generation_config = GenerationConfig.from_pretrained(model_name, trust_remote_code=True)
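# For example (illustrative values, not from the original script); both are
# standard GenerationConfig fields:
# model.generation_config.max_new_tokens = 512
# model.generation_config.top_p = 0.9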
# response, history = model.chat(tokenizer, "你好", history=None)
response, history = model.chat(tokenizer, "你好", history=[])
print(response)
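
# A follow-up turn reusing the returned history (model.chat is the Qwen
# remote-code chat API already used above; the prompt is just an example).
response, history = model.chat(tokenizer, "Please say that again in English.", history=history)
print(response)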