Update services/model_service.py
services/model_service.py
CHANGED
@@ -1,4 +1,3 @@
-# services/model_service.py
 from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaConfig
 from sentence_transformers import SentenceTransformer
 import torch
@@ -28,15 +27,13 @@ class ModelService:
         # Load tokenizer
         self.tokenizer = AutoTokenizer.from_pretrained(settings.MODEL_NAME)

-        # Load model configuration
+        # Load model configuration without modifying rope_scaling
         config = LlamaConfig.from_pretrained(settings.MODEL_NAME)

+        # Remove rope_scaling if present
         if hasattr(config, "rope_scaling"):
-            logger.info("
-            config.rope_scaling = {
-                "type": "linear",  # Ensure the type is valid
-                "factor": 32.0  # Ensure factor is a valid float
-            }
+            logger.info("Removing rope_scaling from configuration...")
+            config.rope_scaling = None

         # Load model with the updated configuration
         self.model = AutoModelForCausalLM.from_pretrained(