"""Test various models."""
# pylint: disable=invalid-name, line-too-long,broad-exception-caught, protected-access
import os
import time

# ruff: noqa: E402
# os.system("pip install --upgrade torch transformers sentencepiece scipy cpm_kernels accelerate bitsandbytes loguru")

# os.system("pip install torch transformers sentencepiece loguru")

from pathlib import Path

import torch
from loguru import logger
from transformers import AutoModel, AutoTokenizer

# fix timezone in Linux
os.environ["TZ"] = "Asia/Shanghai"
try:
    time.tzset()  # type: ignore # pylint: disable=no-member
except Exception:
    # Windows
    logger.warning("Windows, cant run time.tzset()")

model_name = "THUDM/chatglm2-6b-int4"  # 3.9G

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

has_cuda = torch.cuda.is_available()
# has_cuda = False  # force cpu

logger.debug("load")
if has_cuda:
    if model_name.endswith("int4"):
        model = AutoModel.from_pretrained(model_name, trust_remote_code=True).cuda()
    else:
        model = (
            AutoModel.from_pretrained(model_name, trust_remote_code=True).cuda().half()
        )
else:
    # half precision is poorly supported on CPU; fall back to float32 there
    model = AutoModel.from_pretrained(model_name, trust_remote_code=True).float()

model = model.eval()
logger.debug("done load")

# tokenizer = AutoTokenizer.from_pretrained("openchat/openchat_v2_w")
# model = AutoModelForCausalLM.from_pretrained("openchat/openchat_v2_w", load_in_8bit_fp32_cpu_offload=True, load_in_8bit=True)

# model_path = model.config._dict["model_name_or_path"]
# logger.debug(f"{model.config=} {type(model.config)=} {model_path=}")
logger.debug(f"{model.config=}, {type(model.config)=} ")

# model_size_gb = Path(model_path).stat().st_size / 2**30

# logger.info(f"{model_name=} {model_size_gb=:.2f} GB")

# with gr.Blocks() as demo:
#     chatbot = gr.Chatbot()
#     msg = gr.Textbox()
#     clear = gr.ClearButton([msg, chatbot])

#     def respond(message, chat_history):
#         # model.chat already appends the new (message, response) turn to the
#         # history it returns, so no extra chat_history.append() is needed
#         response, chat_history = model.chat(tokenizer, message, history=chat_history, temperature=0.7, repetition_penalty=1.2, max_length=128)
#         return "", chat_history

#     msg.submit(respond, [msg, chatbot], [msg, chatbot])

# demo.launch()
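
# Minimal smoke test (a sketch): ChatGLM checkpoints loaded with
# trust_remote_code expose a .chat() helper, the same API used in the
# commented-out Gradio block above. Prompt and max_length are arbitrary.
response, history = model.chat(tokenizer, "Hello", history=[], max_length=128)
logger.info(f"smoke test response: {response!r}")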