Spaces:
Sleeping
Sleeping
customize the code for peft model
Browse files- src/deepeval/base_task.py +22 -9
src/deepeval/base_task.py
CHANGED
@@ -3,6 +3,7 @@ from datasets import load_dataset
|
|
3 |
import os
|
4 |
from dotenv import load_dotenv
|
5 |
import openai
|
|
|
6 |
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList, LogitsProcessor
|
7 |
import torch
|
8 |
from typing import List
|
@@ -30,20 +31,32 @@ class BaseTask(ABC):
|
|
30 |
return cls._model_cache[model_name]
|
31 |
|
32 |
@staticmethod
|
33 |
-
def load_model(model_name: str, device):
|
34 |
"""Loads model and tokenizer once and caches it."""
|
35 |
print(f"Loading model: {model_name}")
|
36 |
start_time = datetime.now()
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
print(f"Model loaded in {(end_time - start_time).seconds} seconds.")
|
45 |
print("Model loaded.")
|
46 |
-
|
47 |
return model, tokenizer
|
48 |
|
49 |
|
|
|
3 |
import os
|
4 |
from dotenv import load_dotenv
|
5 |
import openai
|
6 |
+
from peft import PeftModel
|
7 |
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList, LogitsProcessor
|
8 |
import torch
|
9 |
from typing import List
|
|
|
31 |
return cls._model_cache[model_name]
|
32 |
|
33 |
@staticmethod
|
34 |
+
def load_model(model_name: str, device, weight, dtype, base_model):
|
35 |
"""Loads model and tokenizer once and caches it."""
|
36 |
print(f"Loading model: {model_name}")
|
37 |
start_time = datetime.now()
|
38 |
+
if weight == "Adapter":
|
39 |
+
base_model_1 = AutoModelForCausalLM.from_pretrained(
|
40 |
+
base_model,
|
41 |
+
torch_dtype=dtype,
|
42 |
+
device_map=device,
|
43 |
+
token=HF_TOKEN, # Replace with actual token
|
44 |
+
)
|
45 |
+
model = PeftModel.from_pretrained(base_model_1, base_model)
|
46 |
+
tokenizer = AutoTokenizer.from_pretrained(base_model)
|
47 |
+
end_time = datetime.now()
|
48 |
+
else:
|
49 |
+
model = AutoModelForCausalLM.from_pretrained(
|
50 |
+
model_name,
|
51 |
+
torch_dtype=dtype,
|
52 |
+
device_map=device,
|
53 |
+
token=HF_TOKEN, # Replace with actual token
|
54 |
+
)
|
55 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
56 |
+
end_time = datetime.now()
|
57 |
print(f"Model loaded in {(end_time - start_time).seconds} seconds.")
|
58 |
print("Model loaded.")
|
59 |
+
|
60 |
return model, tokenizer
|
61 |
|
62 |
|