Spaces:
Paused
Paused
File size: 1,701 Bytes
dae67e9 c90fdd5 dae67e9 c90fdd5 dae67e9 8b634d0 caaf812 dae67e9 926b662 dae67e9 80fb1ef |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GenerationConfig
from peft import PeftModel, PeftConfig
from model import Model
class KoAlpaca(Model):
    """Chat model: a 4-bit quantized causal LM with a PEFT adapter applied on top.

    The adapter's own config names the base model. That base model is loaded with
    bitsandbytes 4-bit NF4 quantization (double quantization, bfloat16 compute).
    The adapter is then wrapped around it. Generation settings come from a local
    ``GenerationConfig`` file.
    """

    def __init__(
        self,
        peft_model_id="4n3mone/Komuchat-koalpaca-polyglot-12.8B",
        gen_config_dir="./models/koalpaca",
        gen_config_file="gen_config.json",
        device_map="auto",
    ):
        """Load the quantized base model, the adapter, the tokenizer and the generation config.

        Args:
            peft_model_id: Hub id or local path of the PEFT adapter. Its config's
                ``base_model_name_or_path`` selects both the base model and the tokenizer.
            gen_config_dir: Directory that holds the generation config file.
            gen_config_file: File name of the generation config inside ``gen_config_dir``.
            device_map: Forwarded to ``AutoModelForCausalLM.from_pretrained``. The default
                ``"auto"`` lets transformers place the weights; ``{"": 0}`` loads the whole
                model onto GPU 0.
        """
        config = PeftConfig.from_pretrained(peft_model_id)
        base_model_id = config.base_model_name_or_path

        self.bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            quantization_config=self.bnb_config,
            device_map=device_map,
        )
        self.model = PeftModel.from_pretrained(base_model, peft_model_id)
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_id)
        self.gen_config = GenerationConfig.from_pretrained(
            gen_config_dir, config_file_name=gen_config_file
        )
        # Prompt template ("질문" = question, "답변" = answer); <INPUT> is replaced
        # by the user's message.
        self.INPUT_FORMAT = "### 질문: <INPUT>\n\n### 답변:"
        self.model.eval()

    def generate(self, inputs):
        """Generate an answer to a single user message.

        Args:
            inputs: The user's message as plain text.

        Returns:
            The text the model produced after the prompt. Special tokens (e.g. EOS)
            are removed and surrounding whitespace is stripped.
        """
        prompt = self.INPUT_FORMAT.replace("<INPUT>", inputs)
        encoded = self.tokenizer(
            prompt,
            return_tensors="pt",
            # NOTE(review): presumably dropped because the base model's forward()
            # does not accept token_type_ids; confirm against the base architecture.
            return_token_type_ids=False,
        ).to(self.model.device)
        output_ids = self.model.generate(**encoded, generation_config=self.gen_config)

        # generate() returns prompt + continuation. Decode only the continuation, so
        # the answer does not depend on re-splitting the decoded text on the
        # "### 답변:" marker. That split breaks when the spacing after the marker
        # differs, or when the marker also appears in the user input or the output.
        prompt_len = encoded["input_ids"].shape[-1]
        answer_ids = output_ids[0][prompt_len:]
        return self.tokenizer.decode(answer_ids, skip_special_tokens=True).strip()
|