Spaces:
Sleeping
Sleeping
Update services/llama_generator.py
Browse files
services/llama_generator.py
CHANGED
@@ -4,6 +4,8 @@ from config.config import GenerationConfig, ModelConfig
|
|
4 |
from typing import List, Dict, Any, Optional, Tuple
|
5 |
from datetime import datetime
|
6 |
import logging
|
|
|
|
|
7 |
from config.config import settings
|
8 |
|
9 |
from services.prompt_builder import LlamaPromptTemplate
|
@@ -50,8 +52,9 @@ class LlamaGenerator(BaseGenerator):
|
|
50 |
print(llama_model_name)
|
51 |
print(prm_model_path)
|
52 |
|
53 |
-
self.model_manager = ModelManager()
|
54 |
-
|
|
|
55 |
self.tokenizer = self.model_manager.load_tokenizer(llama_model_name) # Add this line to initialize the tokenizer
|
56 |
|
57 |
super().__init__(
|
|
|
4 |
from typing import List, Dict, Any, Optional, Tuple
|
5 |
from datetime import datetime
|
6 |
import logging
|
7 |
+
import pytorch
|
8 |
+
|
9 |
from config.config import settings
|
10 |
|
11 |
from services.prompt_builder import LlamaPromptTemplate
|
|
|
52 |
print(llama_model_name)
|
53 |
print(prm_model_path)
|
54 |
|
55 |
+
self.model_manager = ModelManager()
|
56 |
+
|
57 |
+
self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
58 |
self.tokenizer = self.model_manager.load_tokenizer(llama_model_name) # Add this line to initialize the tokenizer
|
59 |
|
60 |
super().__init__(
|