rishi002 committed on
Commit d7ea88b · verified · 1 Parent(s): 9904e1f

Update app.py

Files changed (1)
  1. app.py +39 -54
app.py CHANGED
@@ -1,14 +1,20 @@
 import os
 import gradio as gr
-from langchain_community.vectorstores import FAISS
-from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain.chains import RetrievalQA
 from langchain_core.prompts import PromptTemplate
-from sentence_transformers import SentenceTransformer
-from collections import OrderedDict
-import google.generativeai as genai
 from langchain.llms.base import LLM
+from collections import OrderedDict
 from typing import Optional, List
+import google.generativeai as genai
+
+# Custom utility functions
+from embeddings import (
+    load_pdf_files,
+    create_chunks,
+    get_embedding_model,
+    store_embeddings,
+    load_faiss_db
+)
 
 # Constants
 DATA_PATH = "dataFolder/"
@@ -16,71 +22,57 @@ DB_FAISS_PATH = "/tmp/vectorstore/db_faiss"
 CACHE_DIR = "/tmp/models_cache"
 os.makedirs(CACHE_DIR, exist_ok=True)
 
-# Google AI API setup with better error handling
+# Google AI API setup
 GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
 if not GOOGLE_API_KEY:
-    print("Warning: GOOGLE_API_KEY not found in environment variables!")
+    print("⚠️ GOOGLE_API_KEY not found in environment variables!")
     print("Please set your Google API key in Hugging Face Spaces secrets.")
 else:
     genai.configure(api_key=GOOGLE_API_KEY)
 
-# Load the embedding model
-embedding_model = HuggingFaceEmbeddings(
-    model_name="rishi002/all-MiniLM-L6-v2",
-    cache_folder=CACHE_DIR
-)
-
-# Load or create FAISS database
+# Load or create FAISS vector store
 def load_or_create_faiss():
+    embedding_model = get_embedding_model()
     if not os.path.exists(DB_FAISS_PATH):
-        print("🔄 Creating FAISS Database...")
-        from embeddings import load_pdf_files, create_chunks  # Your custom chunking logic
-
+        print("🔄 FAISS index not found. Creating new index...")
         documents = load_pdf_files(DATA_PATH)
         text_chunks = create_chunks(documents)
-
-        db = FAISS.from_documents(text_chunks, embedding_model)
-        db.save_local(DB_FAISS_PATH)
+        db = store_embeddings(text_chunks, embedding_model, DB_FAISS_PATH)
     else:
-        print("✅ FAISS Database Exists. Loading...")
-
-    return FAISS.load_local(DB_FAISS_PATH, embedding_model, allow_dangerous_deserialization=True)
+        print("✅ Existing FAISS index found. Loading it...")
+        db = load_faiss_db(DB_FAISS_PATH, embedding_model)
+    return db
 
 db = load_or_create_faiss()
 
-# Custom Gemini LLM wrapper for LangChain - Fixed for Hugging Face
-# Custom Gemini LLM wrapper for LangChain - Fixed for Hugging Face
+# ✅ Custom Gemini LLM wrapper for LangChain
 class GeminiLLM(LLM):
     model_name: str = "gemini-2.0-flash"
-
+
     class Config:
-        """Configuration for this pydantic object."""
         extra = 'forbid'
         arbitrary_types_allowed = True
-
+
     def __init__(self, model_name: str = "gemini-2.0-flash", **kwargs):
-        # Initialize only with pydantic-defined fields
         super().__init__(model_name=model_name, **kwargs)
-
+
     def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
         try:
-            # Use local variable, not self.model
             model = genai.GenerativeModel(self.model_name)
             response = model.generate_content(prompt)
             return response.text
         except Exception as e:
             return f"Error generating response: {str(e)}"
-
+
     @property
     def _identifying_params(self):
        return {"model_name": self.model_name}
-
+
     @property
     def _llm_type(self):
         return "gemini"
 
-
-# Updated prompt template with health profile
+# Prompt template with user health profile
 CUSTOM_PROMPT_TEMPLATE = """
 Use the pieces of information provided in the context to answer the user's question.
 If you don't know the answer, just say that you don't know. Don't make up an answer.
@@ -96,17 +88,13 @@ Question: {question}
 Start the answer directly.
 """
 
-# No need for API function anymore - health info will be passed as parameter
-
-# Create qa_chain using Gemini
+# QA Chain constructor
 def create_qa_chain():
     prompt = PromptTemplate(
-        template=CUSTOM_PROMPT_TEMPLATE,
+        template=CUSTOM_PROMPT_TEMPLATE,
         input_variables=["context", "question", "health_info"]
     )
-
     gemini_llm = GeminiLLM()
-
     return RetrievalQA.from_chain_type(
         llm=gemini_llm,
         chain_type="stuff",
@@ -115,41 +103,38 @@ def create_qa_chain():
         chain_type_kwargs={'prompt': prompt}
     )
 
-# Main QA Chain
 qa_chain = create_qa_chain()
 
-# Modified ask_question function with health_info as parameter
+# Function to handle question asking
 def ask_question(query: str, health_info: str = "No health profile provided"):
     try:
-        # Prepare inputs for the QA chain
         qa_inputs = {
             'query': query,
             'health_info': health_info
         }
-
-        # Get response from QA chain
         response = qa_chain.invoke(qa_inputs)
         result = response["result"]
-
-        # Clean up response to remove duplicates
+
+        # Deduplicate output
         sentences = [s.strip() for s in result.split('.') if s.strip()]
         unique_sentences = list(OrderedDict.fromkeys(sentences))
         cleaned_result = '. '.join(unique_sentences) + '.'
-
+
         return cleaned_result, []
-
+
     except Exception as e:
         return f"Error: {str(e)}", []
 
-# Gradio Interface with two inputs
+# Gradio Interface
 iface = gr.Interface(
-    fn=ask_question,
+    fn=ask_question,
    inputs=[
         gr.Textbox(label="Question", placeholder="Enter your question here..."),
         gr.Textbox(label="Health Profile", placeholder="Enter your health information (optional)...", value="No health profile provided")
-    ],
+    ],
     outputs=["text", "json"],
     title="Medical RAG Chatbot",
     description="Ask medical questions and optionally provide your health profile for personalized responses."
 )
-iface.launch(share=True)
+
+iface.launch(share=True)
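Note: the rewritten app.py now imports load_pdf_files, create_chunks, get_embedding_model, store_embeddings, and load_faiss_db from an embeddings module that is not included in this commit. Below is a minimal sketch of what that module could look like, reconstructed from the inline logic this diff removes; the PDF loader and chunking parameters are assumptions, since the original load_pdf_files/create_chunks bodies are never shown.

# embeddings.py: hypothetical sketch. Only get_embedding_model, store_embeddings,
# and load_faiss_db can be inferred directly from the code this diff deletes.
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

CACHE_DIR = "/tmp/models_cache"

def load_pdf_files(data_path):
    # Assumption: load every PDF in the data folder via PyPDFLoader.
    loader = DirectoryLoader(data_path, glob="*.pdf", loader_cls=PyPDFLoader)
    return loader.load()

def create_chunks(documents):
    # Assumption: chunk size and overlap are illustrative, not taken from the repo.
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    return splitter.split_documents(documents)

def get_embedding_model():
    # Mirrors the embedding_model the old app.py constructed inline.
    return HuggingFaceEmbeddings(
        model_name="rishi002/all-MiniLM-L6-v2",
        cache_folder=CACHE_DIR
    )

def store_embeddings(text_chunks, embedding_model, db_path):
    # Mirrors the removed FAISS.from_documents(...) + save_local(...) calls.
    db = FAISS.from_documents(text_chunks, embedding_model)
    db.save_local(db_path)
    return db

def load_faiss_db(db_path, embedding_model):
    # Mirrors the removed FAISS.load_local(...) call.
    return FAISS.load_local(db_path, embedding_model, allow_dangerous_deserialization=True)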
 
 
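For anyone trying the wrapper in isolation: GeminiLLM keeps model_name as its only pydantic field, and _call builds a fresh genai.GenerativeModel on every request. A quick smoke test (hypothetical usage, not part of this commit; assumes GOOGLE_API_KEY is configured):

llm = GeminiLLM()        # defaults to "gemini-2.0-flash"
print(llm._llm_type)     # -> "gemini"
# LangChain's LLM base class routes invoke() to the _call() defined above.
print(llm.invoke("Reply with exactly one word: hello"))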
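The dedup step in ask_question drops exact repeated sentences while keeping their first-seen order. An illustrative run with made-up answer text:

from collections import OrderedDict

result = "Drink water. Rest well. Drink water. Rest well."
sentences = [s.strip() for s in result.split('.') if s.strip()]
unique_sentences = list(OrderedDict.fromkeys(sentences))
print('. '.join(unique_sentences) + '.')  # -> "Drink water. Rest well."

Splitting on '.' is crude (it also splits decimals and abbreviations such as "e.g."), so only sentences that repeat verbatim are removed.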