sitammeur committed · Commit 79f2bc6 · verified · 1 Parent(s): 129e4f6

Update src/app/model.py

Files changed (1)
  1. src/app/model.py +57 -45
src/app/model.py CHANGED
@@ -1,45 +1,57 @@
- # Necessary imports
- import sys
- from typing import Any
- import torch
- from transformers import AutoModel, AutoTokenizer
-
- # Local imports
- from src.logger import logging
- from src.exception import CustomExceptionHandling
-
-
- def load_model_and_tokenizer(model_name: str, device: str) -> Any:
-     """
-     Load the model and tokenizer.
-
-     Args:
-         - model_name (str): The name of the model to load.
-         - device (str): The device to load the model onto.
-
-     Returns:
-         - model: The loaded model.
-         - tokenizer: The loaded tokenizer.
-     """
-     try:
-         # Load the model and tokenizer
-         model = AutoModel.from_pretrained(
-             model_name,
-             trust_remote_code=True,
-             attn_implementation="sdpa",
-             torch_dtype=torch.bfloat16,
-         )
-         model = model.to(device=device)
-         tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-         model.eval()
-
-         # Log the successful loading of the model and tokenizer
-         logging.info("Model and tokenizer loaded successfully.")
-
-         # Return the model and tokenizer
-         return model, tokenizer
-
-     # Handle exceptions that may occur during model and tokenizer loading
-     except Exception as e:
-         # Custom exception handling
-         raise CustomExceptionHandling(e, sys) from e
+ # Necessary imports
+ import os
+ import sys
+ from dotenv import load_dotenv
+ from typing import Any
+ import torch
+ from transformers import AutoModel, AutoTokenizer
+
+ # Local imports
+ from src.logger import logging
+ from src.exception import CustomExceptionHandling
+
+
+ # Load the Environment Variables from .env file
+ load_dotenv()
+
+ # Access token for using the model
+ access_token = os.environ.get("ACCESS_TOKEN")
+
+
+ def load_model_and_tokenizer(model_name: str, device: str) -> Any:
+     """
+     Load the model and tokenizer.
+
+     Args:
+         - model_name (str): The name of the model to load.
+         - device (str): The device to load the model onto.
+
+     Returns:
+         - model: The loaded model.
+         - tokenizer: The loaded tokenizer.
+     """
+     try:
+         # Load the model and tokenizer
+         model = AutoModel.from_pretrained(
+             model_name,
+             trust_remote_code=True,
+             attn_implementation="sdpa",
+             torch_dtype=torch.bfloat16,
+             token=access_token
+         )
+         model = model.to(device=device)
+         tokenizer = AutoTokenizer.from_pretrained(
+             model_name, trust_remote_code=True, token=access_token
+         )
+         model.eval()
+
+         # Log the successful loading of the model and tokenizer
+         logging.info("Model and tokenizer loaded successfully.")
+
+         # Return the model and tokenizer
+         return model, tokenizer
+
+     # Handle exceptions that may occur during model and tokenizer loading
+     except Exception as e:
+         # Custom exception handling
+         raise CustomExceptionHandling(e, sys) from e
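
For reference, a minimal usage sketch of the updated loader. The .env contents, model name, and device string below are illustrative assumptions and are not part of this commit; ACCESS_TOKEN is assumed to be a Hugging Face token with access to the (possibly gated) model.

# .env (assumed contents, read by load_dotenv() at import time)
# ACCESS_TOKEN=hf_xxxxxxxxxxxxxxxxxxxx

# Example call (hypothetical model name; replace with the Space's actual model)
import torch
from src.app.model import load_model_and_tokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
model, tokenizer = load_model_and_tokenizer("some-org/some-model", device)

Passing token=access_token to both from_pretrained calls lets the Space authenticate against the Hub instead of relying on an ambient huggingface-cli login, which is why the commit adds python-dotenv and reads ACCESS_TOKEN from the environment.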