File size: 2,439 Bytes
9b6a4ab
 
8d030a2
 
9c042fd
 
 
 
 
89cb869
8d030a2
de222eb
8d030a2
9b6a4ab
360d9e4
8d030a2
360d9e4
89cb869
 
 
 
 
 
 
 
 
 
8d030a2
 
 
89cb869
 
 
 
d7ec399
89cb869
 
 
 
 
 
 
 
 
 
 
 
 
d7ec399
89cb869
60399ca
 
360d9e4
9c042fd
 
 
360d9e4
9c042fd
da75503
9c042fd
 
 
360d9e4
 
 
 
 
 
 
 
 
 
 
 
 
556ee99
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""Test various models."""
# pylint: disable=invalid-name, line-too-long,broad-exception-caught, protected-access
import os
import time
from pathlib import Path

import torch
from loguru import logger
from transformers import AutoModel, AutoTokenizer

# ruff: noqa: E402
# os.system("pip install --upgrade torch transformers sentencepiece scipy cpm_kernels accelerate bitsandbytes loguru")

# os.system("pip install torch transformers sentencepiece loguru")



# Pin the process-local timezone to Asia/Shanghai.
os.environ["TZ"] = "Asia/Shanghai"
try:
    # tzset() re-reads the TZ variable set above; it is only available on Unix.
    time.tzset()  # type: ignore # pylint: disable=no-member
except AttributeError:
    # Windows: the time module has no tzset attribute. Only this expected
    # failure is handled, so unrelated errors are not mislabelled as "Windows".
    logger.warning("Windows, cant run time.tzset()")

# Hugging Face repo id of the checkpoint. It is the single source of truth for
# the tokenizer, the model and the cache lookup further down in this script.
model_name = "THUDM/chatglm2-6b-int4"  # 3.9G

# Reuse model_name (instead of repeating the literal) so the tokenizer always
# matches the model that is loaded below.
# trust_remote_code=True allows transformers to run the code shipped in the repo.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

has_cuda = torch.cuda.is_available()
# has_cuda = False  # force cpu

logger.debug("load")
# Load once, then choose device/precision; the call order per branch is the
# same as before (from_pretrained -> cuda -> half, or from_pretrained -> half).
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
if has_cuda:
    model = model.cuda()
    # Checkpoints whose name ends in "int4" are left uncast on GPU; any other
    # checkpoint is cast to half precision.
    if not model_name.endswith("int4"):
        model = model.half()
else:
    # NOTE(review): the upstream ChatGLM2-6B README uses .float() for CPU
    # inference; confirm that half precision works on the target CPU / torch
    # version before relying on this branch.
    model = model.half()  # .float() .half().float()

# Switch to inference mode.
model = model.eval()
logger.debug("done load")

# tokenizer = AutoTokenizer.from_pretrained("openchat/openchat_v2_w")
# model = AutoModelForCausalLM.from_pretrained("openchat/openchat_v2_w", load_in_8bit_fp32_cpu_offload=True, load_in_8bit=True)

# Locate the cached weight file(s) for model_name in the default Hugging Face
# hub cache directory and report their size.
cache_loc = Path("~/.cache/huggingface/hub").expanduser()

# Every entry under the cache whose path mentions both the repo's short name
# (the part after the "/") and "pytorch_model.bin".
model_cache_path = [
    elm
    for elm in cache_loc.rglob("*")
    if Path(model_name).name in elm.as_posix()
    and "pytorch_model.bin" in elm.as_posix()
]

logger.debug(f"{model_cache_path=}")

# model_cache_path is a list, so it cannot be passed to Path() directly (that
# raised TypeError). Keep only existing regular files and resolve symlinks so
# several links to the same underlying file are counted once.
weight_files = {elm.resolve() for elm in model_cache_path if elm.is_file()}

if weight_files:
    model_size_gb = sum(elm.stat().st_size for elm in weight_files) / 2**30
    logger.info(f"{model_name=} {model_size_gb=:.2f} GB")

# with gr.Blocks() as demo:
#     chatbot = gr.Chatbot()
#     msg = gr.Textbox()
#     clear = gr.ClearButton([msg, chatbot])

#     def respond(message, chat_history):
#         response, chat_history = model.chat(tokenizer, message, history=chat_history, temperature=0.7, repetition_penalty=1.2, max_length=128)
#         chat_history.append((message, response))
#         return "", chat_history

#     msg.submit(respond, [msg, chatbot], [msg, chatbot])

# demo.launch()