Spaces:
Runtime error
Runtime error
#!python | |
# -*- coding: utf-8 -*- | |
# @author: Kun | |
import os | |
import torch | |
from flagai.auto_model.auto_loader import AutoLoader | |
from flagai.model.predictor.predictor import Predictor | |
from flagai.model.predictor.aquila import aquila_generate | |
from flagai.data.tokenizer import Tokenizer | |
import bminf | |
max_token: int = 128 # 10000 # 64 | |
temperature: float = 0.75 | |
top_p = 0.9 | |
state_dict = "./checkpoints_in" | |
model_name = 'aquilachat-7b' | |
def load_model(): | |
loader = AutoLoader( | |
"lm", | |
model_dir=state_dict, | |
model_name=model_name, | |
use_cache=True, | |
fp16=True) | |
model = loader.get_model() | |
tokenizer = loader.get_tokenizer() | |
cache_dir = os.path.join(state_dict, model_name) | |
model.eval() | |
with torch.cuda.device(0): | |
model = bminf.wrapper(model, quantization=False, memory_limit=2 << 30) | |
return tokenizer, model |