himel06 committed on
Commit 6f95d6d · verified · 1 Parent(s): d9b0e2e

Update BanglaRAG/bangla_rag_pipeline.py

Files changed (1):
  1. BanglaRAG/bangla_rag_pipeline.py +18 -61
BanglaRAG/bangla_rag_pipeline.py CHANGED
@@ -13,100 +13,57 @@ from langchain_community.vectorstores import Chroma
 from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
 from langchain_core.runnables import RunnableParallel, RunnablePassthrough
 from langchain_core.output_parsers import StrOutputParser
+from transformers import AutoModelForCausalLM, AutoTokenizer
 import warnings
 
 warnings.filterwarnings("ignore")
 
 class BanglaRAGChain:
     def __init__(self):
-        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.chat_model_id = None
-        self.embed_model_id = None
-        self.k = 4
-        self.max_new_tokens = 1024
-        self.chunk_size = 500
-        self.chunk_overlap = 150
-        self.text_path = ""
-        self.quantization = None
-        self.temperature = 0.9
-        self.top_p = 0.6
-        self.top_k = 50
-        self._text_content = None
-        self.hf_token = None
-
-        self.tokenizer = None
-        self.chat_model = None
-        self._llm = None
-        self._retriever = None
-        self._db = None
-        self._documents = []
-        self._chain = None
-
-    def load(
-        self,
-        chat_model_id,
-        embed_model_id,
-        text_path,
-        quantization,
-        k=4,
-        top_k=2,
-        top_p=0.6,
-        max_new_tokens=1024,
-        temperature=0.6,
-        chunk_size=500,
-        chunk_overlap=150,
-        hf_token=None,
-    ):
+        # Initialization code...
+        pass
+
+    def load(self, chat_model_id, embed_model_id, text_path, k, top_k, top_p, temperature, chunk_size, chunk_overlap, hf_token, max_new_tokens, quantization, offload_dir=None):
         self.chat_model_id = chat_model_id
         self.embed_model_id = embed_model_id
+        self.text_path = text_path
         self.k = k
         self.top_k = top_k
         self.top_p = top_p
         self.temperature = temperature
         self.chunk_size = chunk_size
         self.chunk_overlap = chunk_overlap
-        self.text_path = text_path
-        self.quantization = quantization
-        self.max_new_tokens = max_new_tokens
         self.hf_token = hf_token
+        self.max_new_tokens = max_new_tokens
+        self.quantization = quantization
+        self.offload_dir = offload_dir  # New parameter
 
-        if self.hf_token is not None:
-            os.environ["HF_TOKEN"] = str(self.hf_token)
-
+        # Load models
         self._load_models()
-        self._create_document()
-        self._update_chroma_db()
-        self._get_retriever()
-        self._get_llm()
-        self._create_chain()
 
     def _load_models(self):
         try:
-            self.tokenizer = AutoTokenizer.from_pretrained(self.chat_model_id)
-            bnb_config = None
             if self.quantization:
-                bnb_config = BitsAndBytesConfig(
-                    load_in_4bit=True,
-                    bnb_4bit_use_double_quant=True,
-                    bnb_4bit_quant_type="nf4",
-                    bnb_4bit_compute_dtype=torch.float16,
-                )
                 self.chat_model = AutoModelForCausalLM.from_pretrained(
                     self.chat_model_id,
-                    load_in_8bit=True,
-                    torch_dtype=torch.bfloat16,
+                    torch_dtype="auto",
                     device_map="auto",
-                    quantization_config=bnb_config,
+                    load_in_4bit=True,
+                    offload_folder=self.offload_dir,  # Offload here
                 )
             else:
                 self.chat_model = AutoModelForCausalLM.from_pretrained(
                     self.chat_model_id,
-                    torch_dtype=torch.bfloat16,
                     device_map="auto",
+                    offload_folder=self.offload_dir,  # Offload here
                 )
+
+            self.tokenizer = AutoTokenizer.from_pretrained(self.chat_model_id)
+
         except Exception as e:
             raise RuntimeError(f"Error loading chat model: {e}")
 
+
     def _create_document(self):
         try:
             with open(self.text_path, "r", encoding="utf-8") as file:
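Two practical notes on the updated API: every load() parameter except offload_dir is now required (the old keyword defaults are gone), and after this commit load() only calls _load_models(), so the document creation, Chroma DB, retriever, and chain setup steps are no longer triggered. A hedged usage sketch follows; every model id, path, and value is illustrative rather than taken from the repository (the numeric values mirror the defaults the old __init__ used).

from BanglaRAG.bangla_rag_pipeline import BanglaRAGChain

rag = BanglaRAGChain()
rag.load(
    chat_model_id="some-org/bangla-chat-model",  # placeholder id
    embed_model_id="some-org/bangla-embedder",   # placeholder id
    text_path="data/corpus.txt",                 # placeholder path
    k=4,
    top_k=50,
    top_p=0.6,
    temperature=0.9,
    chunk_size=500,
    chunk_overlap=150,
    hf_token=None,
    max_new_tokens=1024,
    quantization=True,         # take the 4-bit load path
    offload_dir="./offload",   # new parameter introduced by this commit
)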