rishi002 committed on
Commit 930a98a · verified · 1 Parent(s): d7ea88b

Update embeddings.py

Files changed (1)
  1. embeddings.py +16 -126
embeddings.py CHANGED
@@ -3,161 +3,51 @@ from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_huggingface import HuggingFaceEmbeddings
 from langchain_community.vectorstores import FAISS
-from langchain.chains import RetrievalQA
-from langchain_core.prompts import PromptTemplate
-from langchain.llms.base import LLM
-from typing import Optional, List
-import google.generativeai as genai
 
 # Set Paths
 DATA_PATH = "dataFolder/"
 DB_FAISS_PATH = "/tmp/vectorstore/db_faiss"
 
-# Google AI API setup
-GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
-if not GOOGLE_API_KEY:
-    raise ValueError("GOOGLE_API_KEY environment variable is required!")
-
-genai.configure(api_key=GOOGLE_API_KEY)
-
-# Custom Gemini LLM wrapper for LangChain
-class GeminiLLM(LLM):
-    def __init__(self, model_name="gemini-2.0-flash"):
-        self.model = genai.GenerativeModel(model_name)
-
-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
-        try:
-            response = self.model.generate_content(prompt)
-            return response.text
-        except Exception as e:
-            return f"Error generating response: {str(e)}"
-
-    @property
-    def _identifying_params(self):
-        return {"name": "gemini-flash"}
-
-    @property
-    def _llm_type(self):
-        return "gemini"
-
 # Step 1: Load PDF Files
-def load_pdf_files(data_path):
+def load_pdf_files(data_path=DATA_PATH):
+    print("🔄 Loading PDFs from:", data_path)
     loader = DirectoryLoader(data_path, glob="*.pdf", loader_cls=PyPDFLoader)
     documents = loader.load()
+    print(f"✅ Loaded {len(documents)} document(s).")
     return documents
 
 # Step 2: Create Chunks
 def create_chunks(documents):
+    print("📄 Creating text chunks...")
     text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
     text_chunks = text_splitter.split_documents(documents)
+    print(f"✅ Created {len(text_chunks)} chunk(s).")
     return text_chunks
 
 # Step 3: Generate Embeddings
 def get_embedding_model():
+    print("🧠 Loading embedding model...")
     CACHE_DIR = "/tmp/models_cache"
     os.makedirs(CACHE_DIR, exist_ok=True)
 
     embedding_model = HuggingFaceEmbeddings(
         model_name="rishi002/all-MiniLM-L6-v2",
-        cache_folder="/tmp/models_cache"
+        cache_folder=CACHE_DIR
     )
+    print("✅ Embedding model loaded.")
     return embedding_model
 
 # Step 4: Store Embeddings in FAISS
-def store_embeddings(text_chunks, embedding_model, db_path):
+def store_embeddings(text_chunks, embedding_model, db_path=DB_FAISS_PATH):
+    print("💾 Storing embeddings in FAISS...")
     db = FAISS.from_documents(text_chunks, embedding_model)
    db.save_local(db_path)
+    print(f"✅ FAISS index saved to: {db_path}")
     return db
 
 # Step 5: Load FAISS Database
-def load_faiss_db(db_path, embedding_model):
-    return FAISS.load_local(db_path, embedding_model, allow_dangerous_deserialization=True)
-
-# Step 6: Load Gemini LLM Model
-def load_llm():
-    return GeminiLLM()
-
-# Step 7: Set Custom Prompt with Health Profile
-CUSTOM_PROMPT_TEMPLATE = """
-Use the provided context to answer the user's question.
-If the answer is unknown, say you don't know. Do not make up information.
-Only respond based on the context.
-
-Context: {context}
-User Health Profile: {health_info}
-Question: {question}
-
-Start your answer directly.
-"""
-
-def set_custom_prompt(template):
-    return PromptTemplate(template=template, input_variables=["context", "question", "health_info"])
-
-# Step 8: Create Retrieval QA Chain
-def create_qa_chain(llm, db):
-    return RetrievalQA.from_chain_type(
-        llm=llm,
-        chain_type="stuff",
-        retriever=db.as_retriever(search_kwargs={"k": 3}),
-        return_source_documents=False,
-        chain_type_kwargs={"prompt": set_custom_prompt(CUSTOM_PROMPT_TEMPLATE)}
-    )
-
-# Function to get user health profile via API (placeholder)
-def get_user_health_profile():
-    """
-    This function should make an API call to get the user's health profile.
-    Replace this placeholder with your actual API implementation.
-    """
-    try:
-        # Placeholder - replace with your actual API call
-        return "No health profile available"
-    except Exception as e:
-        print(f"Error fetching health profile: {e}")
-        return "Health profile unavailable"
-
-# Create and load all models and FAISS (for Gradio)
-def prepare_qa_system():
-    # Load and process PDFs, create FAISS index, etc.
-    print("🔄 Loading PDFs...")
-    documents = load_pdf_files(DATA_PATH)
-
-    print("📄 Creating Chunks...")
-    text_chunks = create_chunks(documents)
-
-    print("🧠 Generating Embeddings...")
-    embedding_model = get_embedding_model()
-
-    print("💾 Storing in FAISS...")
-    db = store_embeddings(text_chunks, embedding_model, DB_FAISS_PATH)
-
-    print("🔄 Loading FAISS Database...")
-    db = load_faiss_db(DB_FAISS_PATH, embedding_model)
-
-    print("🤖 Loading Gemini LLM...")
-    llm = load_llm()
-
-    print("🔗 Creating QA Chain...")
-    qa_chain = create_qa_chain(llm, db)
-
-    return qa_chain
-
-# Create the QA system and get the chain ready
-qa_chain = prepare_qa_system()
-
-# Gradio Interface function
-def ask_question(query: str):
-    try:
-        # Get user's health profile via API
-        health_info = get_user_health_profile()
-
-        # Prepare inputs for the QA chain
-        qa_inputs = {
-            'query': query,
-            'health_info': health_info
-        }
-
-        response = qa_chain.invoke(qa_inputs)
-        return response["result"], []
-    except Exception as e:
-        return f"Error: {str(e)}", []
+def load_faiss_db(db_path=DB_FAISS_PATH, embedding_model=None):
+    print("📦 Loading FAISS database from:", db_path)
+    db = FAISS.load_local(db_path, embedding_model, allow_dangerous_deserialization=True)
+    print("✅ FAISS database loaded.")
+    return db
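After this commit, embeddings.py keeps only the indexing pipeline (Steps 1-5); the Gemini wrapper, prompt template, QA chain, and Gradio glue no longer live in this module. Below is a minimal driver sketch showing how a caller might chain the remaining functions. The build_or_load_index helper and the sample query are hypothetical illustrations, not part of the commit.

# Hypothetical driver for the slimmed-down module (not in this commit).
# Assumes embeddings.py is on the import path and dataFolder/ holds the PDFs.
from embeddings import (
    load_pdf_files, create_chunks, get_embedding_model,
    store_embeddings, load_faiss_db, DB_FAISS_PATH,
)

def build_or_load_index():
    """Build the FAISS index from the PDFs, then reload it from disk."""
    embedding_model = get_embedding_model()
    documents = load_pdf_files()                    # defaults to DATA_PATH
    text_chunks = create_chunks(documents)
    store_embeddings(text_chunks, embedding_model)  # defaults to DB_FAISS_PATH
    # embedding_model must be passed explicitly: its default in load_faiss_db is None.
    return load_faiss_db(DB_FAISS_PATH, embedding_model)

if __name__ == "__main__":
    db = build_or_load_index()
    # Sanity-check retrieval with a sample query against the loaded index.
    for doc in db.similarity_search("example query", k=3):
        print(doc.page_content[:200])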