# KOMUChat / koalpaca.py
import torch
import tensor_parallel as tp
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GenerationConfig
from peft import PeftModel, PeftConfig

from model import Model
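
# NOTE: 4-bit loading additionally requires the bitsandbytes package, and
# device_map="auto" requires accelerate (both assumed installed alongside
# the imports above).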

class KoAlpaca(Model):
    def __init__(self):
        # LoRA adapter fine-tuned on the KOMUChat data, on top of the
        # KoAlpaca-Polyglot-12.8B base model.
        peft_model_id = "4n3mone/Komuchat-koalpaca-polyglot-12.8B"
        config = PeftConfig.from_pretrained(peft_model_id)

        # 4-bit NF4 quantization with double quantization; compute in bfloat16.
        self.bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )

        # Load the quantized base model, shard it across both GPUs with
        # tensor_parallel, then attach the LoRA adapter on top.
        self.model = AutoModelForCausalLM.from_pretrained(
            config.base_model_name_or_path,
            quantization_config=self.bnb_config,
            device_map="auto",
        )
        self.model = tp.tensor_parallel(self.model, ["cuda:0", "cuda:1"])
        self.model = PeftModel.from_pretrained(self.model, peft_model_id)
        self.tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

        # Decoding parameters are read from ./models/koalpaca/gen_config.json.
        self.gen_config = GenerationConfig.from_pretrained('./models/koalpaca', 'gen_config.json')

        # KoAlpaca prompt template: "### 질문:" (question) / "### 답변:" (answer).
        self.INPUT_FORMAT = "### 질문: <INPUT>\n\n### 답변:"
        self.model.eval()

    def generate(self, inputs):
        # Wrap the raw user input in the prompt template.
        prompt = self.INPUT_FORMAT.replace('<INPUT>', inputs)
        output_ids = self.model.generate(
            **self.tokenizer(
                prompt,
                return_tensors='pt',
                return_token_type_ids=False
            ).to(self.model.device),
            generation_config=self.gen_config
        )
        # Decode and keep only the text after the answer marker.
        outputs = self.tokenizer.decode(output_ids[0]).split("### 답변: ")[-1]
        return outputs
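

# --- Usage sketch (illustrative, not part of the original file) ---
# Assumes two CUDA GPUs ("cuda:0"/"cuda:1"), the
# 4n3mone/Komuchat-koalpaca-polyglot-12.8B adapter reachable on the Hub,
# and ./models/koalpaca/gen_config.json present locally.
if __name__ == "__main__":
    bot = KoAlpaca()
    # Hypothetical sample prompt: "공부하기 싫을 때 어떻게 해야 할까?"
    # ("What should I do when I don't feel like studying?")
    print(bot.generate("공부하기 싫을 때 어떻게 해야 할까?"))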