# KOMUChat / koalpaca.py
# ElPlaguister - "Feat New Design with Dynamic Tabs" (commit aa0f2ec)
# Hosting-page residue: links "raw", "history", "blame"; file size 1.76 kB
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GenerationConfig
from peft import PeftModel, PeftConfig
from model import Model
class KoAlpaca(Model):
    """Question-answering wrapper around a 4-bit quantized causal LM with a PEFT adapter.

    The base checkpoint is resolved from the adapter's ``PeftConfig``, loaded
    with NF4 4-bit quantization, and then wrapped with the adapter weights.
    Generation settings are read from ``./models/koalpaca/gen_config.json``.

    NOTE(review): ``Model.__init__`` is not invoked here; confirm the base
    class needs no initialization of its own.
    """

    def __init__(self,
                 name: str = 'KoAlpaca'):
        """Load the tokenizer, quantized base model, adapter, and generation config.

        Args:
            name: Display name stored on the instance as ``self.name``.
        """
        self.name = name
        peft_model_id = "4n3mone/Komuchat-koalpaca-polyglot-12.8B"
        # The adapter config records which base checkpoint it was trained on.
        config = PeftConfig.from_pretrained(peft_model_id)
        self.bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        # device_map='auto' delegates layer placement to the loader; an earlier
        # revision pinned everything to device 0 with device_map={"": 0}.
        self.model = AutoModelForCausalLM.from_pretrained(
            config.base_model_name_or_path,
            quantization_config=self.bnb_config,
            device_map='auto',
        )
        self.model = PeftModel.from_pretrained(self.model, peft_model_id)
        self.tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
        self.gen_config = GenerationConfig.from_pretrained('./models/koalpaca', 'gen_config.json')
        # Prompt template ("question" / "answer" headers); <INPUT> is replaced
        # with the user's message in generate().
        self.INPUT_FORMAT = "### 질문: <INPUT>\n\n### 답변:"
        self.model.eval()

    def generate(self, inputs: str) -> str:
        """Generate an answer for a single user message.

        Args:
            inputs: The user's message; substituted into ``INPUT_FORMAT``.

        Returns:
            The generated answer text only, without the prompt, without
            special tokens, and with surrounding whitespace stripped.
        """
        prompt = self.INPUT_FORMAT.replace('<INPUT>', inputs)
        encoded = self.tokenizer(
            prompt,
            return_tensors='pt',
            return_token_type_ids=False,
        ).to(self.model.device)
        output_ids = self.model.generate(
            **encoded,
            generation_config=self.gen_config,
        )
        # A causal LM returns prompt + completion in one sequence. Drop the
        # prompt by token count instead of splitting the decoded text on the
        # answer header: the template ends in "### 답변:" (no trailing space),
        # so a text split on "### 답변: " can miss and leak the whole prompt,
        # and it also misfires when the message itself contains the header.
        prompt_length = encoded['input_ids'].shape[-1]
        answer_ids = output_ids[0][prompt_length:]
        return self.tokenizer.decode(answer_ids, skip_special_tokens=True).strip()