Commit 982b2c3
Parent(s):
Feat: gemma coffee model

Files changed:
- gemma_Ko_coffee.py +93 -0
- gemma_Ko_coffee_load_model.py +46 -0
gemma_Ko_coffee.py
ADDED
@@ -0,0 +1,93 @@
import torch
# import os
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, TrainingArguments
from peft import LoraConfig, PeftModel
from trl import SFTTrainer
import warnings

# Suppress a specific FutureWarning raised by huggingface_hub downloads
warnings.filterwarnings("ignore", category=FutureWarning, module="huggingface_hub.file_download")

# Device setup (CPU or MPS)
device = torch.device("cpu")

# Load the data
data_path = "./data_finetunned/coffee_finetuning_20240914_witi_total.jsonl"
dataset = Dataset.from_json(data_path)
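# Note: generate_prompt below reads the 'instruction' and 'response' fields, so each
# JSONL record is assumed to look roughly like this (hypothetical example):
# {"instruction": "Recommend a coffee with bright acidity.", "response": "Try an Ethiopian Yirgacheffe..."}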

print("Dataset loaded")

# Load the model and tokenizer
BASE_MODEL = "beomi/gemma-ko-2b"
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)
model.to(device)

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, add_special_tokens=True)


# Prompt-building function (renders instruction/response pairs as Gemma chat turns)
def generate_prompt(example):
    prompt_list = []
    for i in range(len(example['instruction'])):
        prompt_list.append(f"""<bos><start_of_turn>user
{example['instruction'][i]}<end_of_turn>
<start_of_turn>model
{example['response'][i]}<end_of_turn><eos>""")
    return prompt_list

# Use the whole dataset as training data
train_data = dataset

# Inspect the prompt built from the first example
print(generate_prompt(train_data[:1])[0])

# LoRA configuration
lora_config = LoraConfig(
    r=6,
    lora_alpha=8,
    lora_dropout=0.05,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)

# Trainer setup
trainer = SFTTrainer(
    model=model,
    train_dataset=train_data,
    max_seq_length=512,
    args=TrainingArguments(
        output_dir="outputs",
        max_steps=3000,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        optim="adamw_torch",
        warmup_ratio=0.03,  # warm up the learning rate over 3% of steps (warmup_steps expects an integer count)
        learning_rate=2e-4,
        fp16=False,
        logging_steps=100,
        push_to_hub=False,
        report_to='none',
        use_mps_device=False  # run on CPU
    ),
    peft_config=lora_config,
    formatting_func=generate_prompt,  # apply the prompt-formatting function above
)
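# Note: with per_device_train_batch_size=1 and gradient_accumulation_steps=4,
# each optimizer step effectively sees 1 * 4 = 4 training examples.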

# Start training
trainer.train()

# Save the LoRA adapter
ADAPTER_MODEL = "lora_adapter"
trainer.model.save_pretrained(ADAPTER_MODEL)

# Merge the adapter into the base model and save the final model
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map='auto', torch_dtype=torch.float16)
model = PeftModel.from_pretrained(model, ADAPTER_MODEL, device_map='auto', torch_dtype=torch.float16)

model = model.merge_and_unload()
model.save_pretrained('./gemma_outputs/gemma-ko-2b-beans-20240915-01')
print("Model saved")
# Save the tokenizer
tokenizer.save_pretrained('./gemma_outputs/gemma-ko-2b-beans-20240915-01')
print("Tokenizer saved")
gemma_Ko_coffee_load_model.py
ADDED
@@ -0,0 +1,46 @@
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Load the tokenizer from the base model
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")

# Path to the saved (merged) model
model_dir = './gemma_outputs/gemma-2b-it-sum-ko-beans-1'
model = AutoModelForCausalLM.from_pretrained(model_dir)
# tokenizer = AutoTokenizer.from_pretrained(model_dir)

# Move the model to CPU (switch to 'cuda' if you have a GPU)
model.to("cpu")

conversation_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=512)


def chat_with_model(input_text):
    # Build the chat message for the prompt
    messages = [{"role": "user", "content": input_text}]

    # Render the message into a prompt string with the tokenizer's chat template
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # Generate a response with the model
    # response = conversation_pipeline(prompt, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
    response = conversation_pipeline(prompt, do_sample=True, temperature=0.2, top_k=50, top_p=0.95, add_special_tokens=True)

    # Extract only the newly generated text
    generated_text = response[0]["generated_text"]
    model_response = generated_text[len(prompt):]  # drop the input prompt, keep only the reply
    return model_response
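# A minimal one-off usage sketch (the prompt string is hypothetical), in case you
# do not want the interactive loop below:
#   print(chat_with_model("Recommend a coffee bean with bright acidity."))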


# Simple loop that keeps the conversation going
def interactive_chat():
    print("Welcome to interactive mode! Type '종료' to end the conversation.")
    while True:
        user_input = input("User: ")  # read user input
        if user_input.lower() == "종료":  # typing '종료' ends the conversation
            print("Ending the conversation.")
            break
        model_reply = chat_with_model(user_input)  # get the model's reply
        print(f"Model: {model_reply}")  # print the model's reply

# Start the conversation
interactive_chat()