File size: 1,799 Bytes
dae67e9
 
 
c90fdd5
dae67e9
c90fdd5
aa0f2ec
 
 
dae67e9
 
 
 
 
 
 
 
8b634d0
caaf812
dae67e9
 
 
 
 
 
d753197
 
dae67e9
 
 
 
 
 
 
926b662
dae67e9
 
 
80fb1ef
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GenerationConfig
from peft import PeftModel, PeftConfig
from model import Model

class KoAlpaca(Model):
    """Polyglot-Ko base causal LM with a PEFT adapter, loaded in 4-bit, answering single-turn prompts.

    The base model name is read from the adapter's ``PeftConfig``. The base
    weights are quantized with bitsandbytes (NF4, double quantization, bf16
    compute), and prompts use the ``### 질문: ... ### 답변:`` template.
    """

    def __init__(self,
                 name: str = 'KoAlpaca',
                 peft_model_id: str = "4n3mone/Komuchat-koalpaca-polyglot-12.8B",
                 gen_config_dir: str = './models/koalpaca',
                 gen_config_file: str = 'gen_config.json'):
        """Load the quantized base model, attach the adapter, and read generation settings.

        Args:
            name: Name stored on the instance as ``self.name``.
            peft_model_id: Hub id or local path of the PEFT adapter. The base
                model and tokenizer come from its ``base_model_name_or_path``.
            gen_config_dir: Directory passed to ``GenerationConfig.from_pretrained``.
            gen_config_file: Generation-config file name inside ``gen_config_dir``.
        """
        self.name = name
        config = PeftConfig.from_pretrained(peft_model_id)
        # 4-bit NF4 weights with nested (double) quantization; compute runs in bfloat16.
        self.bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        # device_map='auto' lets accelerate place the layers on the available device(s).
        base_model = AutoModelForCausalLM.from_pretrained(
            config.base_model_name_or_path,
            quantization_config=self.bnb_config,
            device_map='auto',
        )
        self.model = PeftModel.from_pretrained(base_model, peft_model_id)
        self.tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
        self.gen_config = GenerationConfig.from_pretrained(gen_config_dir, gen_config_file)
        # Prompt template; '<INPUT>' is replaced by the user's question in generate().
        self.INPUT_FORMAT = "### 질문: <INPUT>\n\n### 답변:"
        self.model.eval()

        # NOTE(review): super().__init__() runs after the attributes above are set
        # (same order as before); confirm Model.__init__ does not overwrite them.
        super().__init__()

    def generate(self, inputs: str) -> str:
        """Answer a single question and return only the model's generated text.

        Args:
            inputs: The user's question, inserted into ``INPUT_FORMAT``.

        Returns:
            The decoded continuation after the prompt, without special tokens
            (e.g. EOS) and with surrounding whitespace stripped.
        """
        prompt = self.INPUT_FORMAT.replace('<INPUT>', inputs)
        encoded = self.tokenizer(
            prompt,
            return_tensors='pt',
            return_token_type_ids=False,
        ).to(self.model.device)
        output_ids = self.model.generate(**encoded, generation_config=self.gen_config)
        # For a decoder-only model the output begins with the prompt tokens; slice
        # them off instead of searching the decoded text for the answer marker,
        # which breaks if the marker is absent or reappears in the answer.
        prompt_len = encoded['input_ids'].shape[-1]
        new_tokens = output_ids[0][prompt_len:]
        # skip_special_tokens keeps EOS/pad tokens out of the returned answer.
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()