joeykims committed
Commit 982b2c3 · 0 Parent(s)

Feat: gemma coffee model

Files changed (2)
  1. gemma_Ko_coffee.py +93 -0
  2. gemma_Ko_coffee_load_model.py +46 -0
gemma_Ko_coffee.py ADDED
@@ -0,0 +1,93 @@
+ import torch
+ # import os
+ from datasets import Dataset
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, TrainingArguments
+ from peft import LoraConfig, PeftModel
+ from trl import SFTTrainer
+ import warnings
+
+ # Suppress a specific FutureWarning from huggingface_hub
+ warnings.filterwarnings("ignore", category=FutureWarning, module="huggingface_hub.file_download")
+
+ # Device setup (CPU or MPS)
+ device = torch.device("cpu")
+
+ # Load the dataset
+ data_path = "./data_finetunned/coffee_finetuning_20240914_witi_total.jsonl"
+ dataset = Dataset.from_json(data_path)
+
+ print("Dataset loaded")
+
+ # Load the model and tokenizer
+ BASE_MODEL = "beomi/gemma-ko-2b"
+ model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)
+ model.to(device)
+
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, add_special_tokens=True)
+
+
+ # Build a training prompt from each instruction/response pair
+ def generate_prompt(example):
+     prompt_list = []
+     for i in range(len(example['instruction'])):
+         prompt_list.append(f"""<bos><start_of_turn>user
+ {example['instruction'][i]}<end_of_turn>
+ <start_of_turn>model
+ {example['response'][i]}<end_of_turn><eos>""")
+     return prompt_list
+
+ # Use the full dataset as training data
+ train_data = dataset
+
+ # Inspect the prompt built from the first example
+ print(generate_prompt(train_data[:1])[0])
+
+ # LoRA configuration
+ lora_config = LoraConfig(
+     r=6,
+     lora_alpha=8,
+     lora_dropout=0.05,
+     target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
+     task_type="CAUSAL_LM",
+ )
+
+ # Trainer setup
+ trainer = SFTTrainer(
+     model=model,
+     train_dataset=train_data,
+     max_seq_length=512,
+     args=TrainingArguments(
+         output_dir="outputs",
+         max_steps=3000,
+         per_device_train_batch_size=1,
+         gradient_accumulation_steps=4,
+         optim="adamw_torch",
+         warmup_ratio=0.03,  # warm up over 3% of training steps
+         learning_rate=2e-4,
+         fp16=False,
+         logging_steps=100,
+         push_to_hub=False,
+         report_to='none',
+         use_mps_device=False  # run on CPU
+     ),
+     peft_config=lora_config,
+     formatting_func=generate_prompt,  # apply the prompt-building function above
+ )
+
+ # Start training
+ trainer.train()
+
+ # Save the LoRA adapter
+ ADAPTER_MODEL = "lora_adapter"
+ trainer.model.save_pretrained(ADAPTER_MODEL)
+
+ # Merge the adapter into the base model and save the result
+ model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map='auto', torch_dtype=torch.float16)
+ model = PeftModel.from_pretrained(model, ADAPTER_MODEL, device_map='auto', torch_dtype=torch.float16)
+
+ model = model.merge_and_unload()
+ model.save_pretrained('./gemma_outputs/gemma-ko-2b-beans-20240915-01')
+ print("Model saved")
+ # Save the tokenizer alongside the merged model
+ tokenizer.save_pretrained('./gemma_outputs/gemma-ko-2b-beans-20240915-01')
+ print("Tokenizer saved")
gemma_Ko_coffee_load_model.py ADDED
@@ -0,0 +1,46 @@
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+
+ # Load the tokenizer from the base model
+ tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
+
+ # Path to the saved fine-tuned model
+ model_dir = './gemma_outputs/gemma-2b-it-sum-ko-beans-1'
+ model = AutoModelForCausalLM.from_pretrained(model_dir)
+ # tokenizer = AutoTokenizer.from_pretrained(model_dir)
+
+ # Move the model to CPU (switch to 'cuda' if using a GPU)
+ model.to("cpu")
+
+ conversation_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=512)
+
+
+ def chat_with_model(input_text):
+     # Build the chat messages
+     messages = [{"role": "user", "content": input_text}]
+
+     # Convert the messages into a prompt string using the chat template
+     prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
+     # Generate a response with the model
+     # response = conversation_pipeline(prompt, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
+     response = conversation_pipeline(prompt, do_sample=True, temperature=0.2, top_k=50, top_p=0.95, add_special_tokens=True)
+
+     # Extract the generated text
+     generated_text = response[0]["generated_text"]
+     model_response = generated_text[len(prompt):]  # strip the input prompt and return only the reply
+     return model_response
+
+
+ # Simple loop for carrying on a conversation
+ def interactive_chat():
+     print("Welcome to interactive mode! Type 'exit' to end the conversation.")
+     while True:
+         user_input = input("User: ")  # read user input
+         if user_input.lower() == "exit":  # typing 'exit' ends the conversation
+             print("Ending the conversation.")
+             break
+         model_reply = chat_with_model(user_input)  # get the model's reply
+         print(f"Model: {model_reply}")  # print the model's reply
+
+ # Start the conversation
+ interactive_chat()