Spaces:
Runtime error
Runtime error
File size: 883 Bytes
591004d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
#!python
# -*- coding: utf-8 -*-
# @author: Kun
import os
import torch
from flagai.auto_model.auto_loader import AutoLoader
from flagai.model.predictor.predictor import Predictor
from flagai.model.predictor.aquila import aquila_generate
from flagai.data.tokenizer import Tokenizer
import bminf
# Generation settings. Values used by earlier experiments for max_token
# were 10000 and 64.
max_token: int = 128
temperature: float = 0.75
top_p: float = 0.9

# Checkpoint settings. Despite its name, `state_dict` is a directory path:
# load_model() hands it to AutoLoader as `model_dir`.
state_dict: str = "./checkpoints_in"
model_name: str = 'aquilachat-7b'
def load_model(*, device: int = 0, memory_limit: int = 2 << 30):
    """Load the AquilaChat tokenizer and model, then wrap the model with BMInf.

    The checkpoint is resolved by flagai's ``AutoLoader`` from the
    module-level ``state_dict`` directory and ``model_name``.

    Args:
        device: CUDA device index made current (via ``torch.cuda.device``)
            while BMInf wraps the model. Defaults to 0, the value that was
            previously hard-coded.
        memory_limit: Forwarded to ``bminf.wrapper`` as ``memory_limit``
            (a byte count per the BMInf docs). Defaults to ``2 << 30``
            (2 GiB), the value that was previously hard-coded.

    Returns:
        A ``(tokenizer, model)`` tuple, where ``model`` is the object
        returned by ``bminf.wrapper``.
    """
    loader = AutoLoader(
        "lm",
        model_dir=state_dict,
        model_name=model_name,
        use_cache=True,
        fp16=True,  # presumably half-precision weights; see flagai AutoLoader docs
    )
    model = loader.get_model()
    tokenizer = loader.get_tokenizer()
    # Inference only: put the model in evaluation mode.
    model.eval()
    # Make `device` the current CUDA device while BMInf wraps the model.
    with torch.cuda.device(device):
        model = bminf.wrapper(
            model, quantization=False, memory_limit=memory_limit
        )
    return tokenizer, model