Update services/model_service.py
Browse files
services/model_service.py
CHANGED
@@ -34,6 +34,7 @@ class ModelService:
|
|
34 |
|
35 |
# Load model with the updated configuration
|
36 |
self.model = AutoModelForCausalLM.from_pretrained(
|
|
|
37 |
settings.MODEL_NAME,
|
38 |
torch_dtype=torch.float16 if settings.DEVICE == "cuda" else torch.float32,
|
39 |
device_map="auto" if settings.DEVICE == "cuda" else None
|
|
|
34 |
|
35 |
# Load model with the updated configuration
|
36 |
self.model = AutoModelForCausalLM.from_pretrained(
|
37 |
+
model_type == "llama" ,
|
38 |
settings.MODEL_NAME,
|
39 |
torch_dtype=torch.float16 if settings.DEVICE == "cuda" else torch.float32,
|
40 |
device_map="auto" if settings.DEVICE == "cuda" else None
|