File size: 2,545 Bytes
387a8a0
 
 
fcfe249
387a8a0
 
 
 
 
 
 
 
 
 
 
 
 
 
fcfe249
387a8a0
fcfe249
 
 
 
 
 
387a8a0
 
fcfe249
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import keras_nlp

# Default preset name; load_model_with_lora passes it to
# keras_nlp.models.GemmaCausalLM.from_preset.
MODEL_NAME = "gemma2_instruct_2b_en"
# Default LoRA weights file; load_model_with_lora passes it to
# model.backbone.load_lora_weights.
# NOTE(review): the path is relative — presumably resolved against the
# process working directory; confirm against how the app is launched.
LORA_WEIGHT_PATH = "ice_breaking_challenge/models/gemma2_it_2b_icebreaking.lora.h5"

def load_model_with_lora(
    model_name: str = MODEL_NAME,
    lora_weight_path: str = LORA_WEIGHT_PATH,
    *,
    run_sample_generation: bool = True,
    sample_max_length: int = 512,
) -> keras_nlp.models.GemmaCausalLM:
    """Load a Gemma causal LM preset and apply LoRA weights to its backbone.

    Args:
        model_name: Preset name handed to ``GemmaCausalLM.from_preset``.
        lora_weight_path: Path of the LoRA weights file loaded into the
            model's backbone.
        run_sample_generation: When true (the default, which matches the
            previous unconditional behavior), generate from a built-in
            sample prompt and print the result as a smoke test. Pass
            ``False`` to load the model without the extra generation pass
            and stdout output.
        sample_max_length: ``max_length`` forwarded to ``model.generate``
            for the sample generation. Ignored when
            ``run_sample_generation`` is false.

    Returns:
        The loaded model with the LoRA weights applied.
    """
    model = keras_nlp.models.GemmaCausalLM.from_preset(model_name)
    model.backbone.load_lora_weights(lora_weight_path)

    if run_sample_generation:
        # Debug smoke test: run one generation on a fixed question/answer
        # pair so a broken model or weights file shows up at load time.
        # NOTE(review): these sample strings look mis-encoded (UTF-8 Korean
        # read with the wrong codec) — confirm the file's encoding. They are
        # kept byte-for-byte here so runtime behavior is unchanged.
        question_crawling = "λ‚˜μ˜ 이런 점은 일할 λ•Œ 도움이 돼!?"
        answer_crawling = "λ‚˜λˆ„κ³  μ‹Άμ–΄ν•˜λŠ” 마음? μ£Όλ³€ μ‚¬λžŒλ“€μ€ 그만 퍼주라고 ν•˜κΈ°λ„ ν•˜μ§€λ§Œ, λ‚΄κ°€ ν΄λΌμ΄μ–ΈνŠΈλ‘œλΆ€ν„° λˆμ„ 벌고자 ν•˜λŠ” 것이 μ•„λ‹ˆλΌ μ‘°κΈˆμ΄λΌλ„ 더 μ±™κ²¨μ£Όκ³ μž ν•˜λŠ” λ§ˆμŒμ„ κ°€μ‘Œμ„ λ•Œ κ²°κ΅­ λ‚˜μ˜ λΈŒλžœλ“œκ°€ 훨씬 더 컀질 수 μžˆλ‹€λŠ” 믿음이 μžˆλ‹€."

        input_text = f"{question_crawling} {answer_crawling}"

        print(model.generate(input_text, max_length=sample_max_length))

    return model


# def template_setting(df:pd.DataFrame, is_test:bool) -> np.ndarray:
#     template_input="""
#     <instruction>
#     Using the text: {question_crawling} {answer_crawling}, create a new multiple-choice question with 4 answer options.
#     """
    
#     template_output="""
#     <Response>
#     {question_generated}
#     {multiple_choice_generated}
#     {answer_generated}
#     """
#     template=template_input+'\n'+template_output
    
#     inputs = np.array(df.apply(lambda row: template.format(
#                         question_crawling=row['question_crawling'],
#                         answer_crawling=row['answer_crawling'],
#                         question_generated=row['question_generated'] if not is_test else "",
#                         multiple_choice_generated=row['multiple_choice_generated'] if not is_test else "",
#                         answer_generated=row['answer_generated'] if not is_test else "").strip(), axis=1))
    
#     outputs = np.array(df.apply(lambda row: template_output.format(
#                         question_generated=row['question_generated'],
#                         multiple_choice_generated=row['multiple_choice_generated'],
#                         answer_generated=row['answer_generated']).strip(), axis=1))
#     combined_array = np.column_stack((inputs, outputs))
#     return combined_array