OG3850 committed on
Commit
b30c279
·
1 Parent(s): 4b8421d

customize the code for peft model

Browse files
Files changed (1) hide show
  1. src/deepeval/base_task.py +22 -9
src/deepeval/base_task.py CHANGED
@@ -3,6 +3,7 @@ from datasets import load_dataset
3
  import os
4
  from dotenv import load_dotenv
5
  import openai
 
6
  from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList, LogitsProcessor
7
  import torch
8
  from typing import List
@@ -30,20 +31,32 @@ class BaseTask(ABC):
30
  return cls._model_cache[model_name]
31
 
32
  @staticmethod
33
def load_model(model_name: str, device):
    """Load a half-precision causal LM and its tokenizer for ``model_name``.

    Args:
        model_name: Identifier handed to both ``from_pretrained`` calls.
        device: Forwarded unchanged as ``device_map``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    print(f"Loading model: {model_name}")

    # Only the model load is timed; the tokenizer is fetched afterwards.
    load_kwargs = {
        "torch_dtype": torch.float16,
        "device_map": device,
        "token": HF_TOKEN,
    }
    began = datetime.now()
    loaded_model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
    elapsed = datetime.now() - began

    print(f"Model loaded in {elapsed.seconds} seconds.")
    print("Model loaded.")

    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    return loaded_model, loaded_tokenizer
48
 
49
 
 
3
  import os
4
  from dotenv import load_dotenv
5
  import openai
6
+ from peft import PeftModel
7
  from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList, LogitsProcessor
8
  import torch
9
  from typing import List
 
31
  return cls._model_cache[model_name]
32
 
33
  @staticmethod
34
def load_model(model_name: str, device, weight, dtype, base_model):
    """Load a causal language model and its tokenizer.

    When ``weight`` equals ``"Adapter"``, ``base_model`` is loaded first and
    the PEFT adapter identified by ``model_name`` is attached on top of it;
    the tokenizer is taken from ``base_model``. For any other ``weight``
    value, ``model_name`` is loaded directly as a full checkpoint together
    with its own tokenizer.

    This function does no caching itself.

    Args:
        model_name: Adapter identifier in the adapter case, otherwise the
            full model identifier.
        device: Forwarded unchanged as ``device_map``.
        weight: Selects the loading path; only the exact string
            ``"Adapter"`` triggers the PEFT branch.
        dtype: Forwarded unchanged as ``torch_dtype``.
        base_model: Base checkpoint identifier; only read in the adapter
            branch.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    print(f"Loading model: {model_name}")
    start_time = datetime.now()

    if weight == "Adapter":
        backbone = AutoModelForCausalLM.from_pretrained(
            base_model,
            torch_dtype=dtype,
            device_map=device,
            token=HF_TOKEN,
        )
        # The adapter weights are identified by model_name, not base_model:
        # passing the base checkpoint here would never load the adapter.
        model = PeftModel.from_pretrained(backbone, model_name)
        tokenizer = AutoTokenizer.from_pretrained(base_model)
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=dtype,
            device_map=device,
            token=HF_TOKEN,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Taken once after either branch; covers model and tokenizer loading.
    end_time = datetime.now()
    print(f"Model loaded in {(end_time - start_time).seconds} seconds.")
    print("Model loaded.")

    return model, tokenizer
61
 
62