dnnsdunca commited on
Commit
cd17556
1 Parent(s): 7abc5b6

Update src/model.py

Browse files
Files changed (1) hide show
  1. src/model.py +18 -0
src/model.py CHANGED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer
2
+ from peft import get_peft_model, LoraConfig, TaskType
3
+
4
+ def get_model_and_tokenizer(config):
5
+ model = AutoModelForCausalLM.from_pretrained(config['model']['name'])
6
+ tokenizer = AutoTokenizer.from_pretrained(config['model']['name'])
7
+
8
+ # Add LoRA adapters for fine-tuning
9
+ peft_config = LoraConfig(
10
+ task_type=TaskType.CAUSAL_LM,
11
+ inference_mode=False,
12
+ r=8,
13
+ lora_alpha=32,
14
+ lora_dropout=0.1
15
+ )
16
+ model = get_peft_model(model, peft_config)
17
+
18
+ return model, tokenizer