Update app.py
app.py CHANGED
@@ -1,14 +1,20 @@
 import os
 import gradio as gr
-from langchain_community.vectorstores import FAISS
-from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain.chains import RetrievalQA
 from langchain_core.prompts import PromptTemplate
-from sentence_transformers import SentenceTransformer
-from collections import OrderedDict
-import google.generativeai as genai
 from langchain.llms.base import LLM
+from collections import OrderedDict
 from typing import Optional, List
+import google.generativeai as genai
+
+# Custom utility functions
+from embeddings import (
+    load_pdf_files,
+    create_chunks,
+    get_embedding_model,
+    store_embeddings,
+    load_faiss_db
+)
 
 # Constants
 DATA_PATH = "dataFolder/"
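The new import block moves PDF loading, chunking, embedding-model setup, and FAISS persistence out of app.py and into a local embeddings module, which is not included in this commit. Below is a minimal sketch of what such a module could look like. The embedding model name and cache directory are carried over from the code removed in the hunks that follow; the loader, splitter, and chunk sizes are assumptions, not the repository's actual implementation.

# embeddings.py - hypothetical sketch only; the real module is not part of this commit
import os
from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

CACHE_DIR = "/tmp/models_cache"  # same cache directory app.py uses

def load_pdf_files(data_path):
    # Load every PDF under data_path into LangChain Document objects.
    loader = DirectoryLoader(data_path, glob="*.pdf", loader_cls=PyPDFLoader)
    return loader.load()

def create_chunks(documents, chunk_size=500, chunk_overlap=50):
    # Split documents into overlapping chunks for retrieval (sizes are assumed).
    splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    return splitter.split_documents(documents)

def get_embedding_model():
    # Same embedding model the old app.py instantiated inline.
    return HuggingFaceEmbeddings(model_name="rishi002/all-MiniLM-L6-v2", cache_folder=CACHE_DIR)

def store_embeddings(text_chunks, embedding_model, db_path):
    # Build a FAISS index from the chunks and persist it to disk.
    db = FAISS.from_documents(text_chunks, embedding_model)
    db.save_local(db_path)
    return db

def load_faiss_db(db_path, embedding_model):
    # Reload a previously saved FAISS index (the flag is required on recent LangChain versions).
    return FAISS.load_local(db_path, embedding_model, allow_dangerous_deserialization=True)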
@@ -16,71 +22,57 @@ DB_FAISS_PATH = "/tmp/vectorstore/db_faiss"
 CACHE_DIR = "/tmp/models_cache"
 os.makedirs(CACHE_DIR, exist_ok=True)
 
-# Google AI API setup
+# Google AI API setup
 GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
 if not GOOGLE_API_KEY:
-    print("
+    print("⚠️ GOOGLE_API_KEY not found in environment variables!")
     print("Please set your Google API key in Hugging Face Spaces secrets.")
 else:
     genai.configure(api_key=GOOGLE_API_KEY)
 
-# Load
-embedding_model = HuggingFaceEmbeddings(
-    model_name="rishi002/all-MiniLM-L6-v2",
-    cache_folder=CACHE_DIR
-)
-
-# Load or create FAISS database
+# Load or create FAISS vector store
 def load_or_create_faiss():
+    embedding_model = get_embedding_model()
     if not os.path.exists(DB_FAISS_PATH):
-        print("Creating
-        from embeddings import load_pdf_files, create_chunks  # Your custom chunking logic
-
+        print("FAISS index not found. Creating new index...")
         documents = load_pdf_files(DATA_PATH)
         text_chunks = create_chunks(documents)
-
-        db = FAISS.from_documents(text_chunks, embedding_model)
-        db.save_local(DB_FAISS_PATH)
+        db = store_embeddings(text_chunks, embedding_model, DB_FAISS_PATH)
     else:
-        print("FAISS
-
-    return
+        print("✅ Existing FAISS index found. Loading it...")
+        db = load_faiss_db(DB_FAISS_PATH, embedding_model)
+    return db
 
 db = load_or_create_faiss()
 
-# Custom Gemini LLM wrapper for LangChain
-# Custom Gemini LLM wrapper for LangChain - Fixed for Hugging Face
+# ✅ Custom Gemini LLM wrapper for LangChain
 class GeminiLLM(LLM):
     model_name: str = "gemini-2.0-flash"
-
+
     class Config:
-        """Configuration for this pydantic object."""
         extra = 'forbid'
         arbitrary_types_allowed = True
-
+
     def __init__(self, model_name: str = "gemini-2.0-flash", **kwargs):
-        # Initialize only with pydantic-defined fields
         super().__init__(model_name=model_name, **kwargs)
-
+
     def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
         try:
-            # Use local variable, not self.model
             model = genai.GenerativeModel(self.model_name)
             response = model.generate_content(prompt)
             return response.text
         except Exception as e:
             return f"Error generating response: {str(e)}"
-
+
     @property
     def _identifying_params(self):
         return {"model_name": self.model_name}
-
+
     @property
     def _llm_type(self):
         return "gemini"
 
-
-# Updated prompt template with health profile
+# Prompt template with user health profile
 CUSTOM_PROMPT_TEMPLATE = """
 Use the pieces of information provided in the context to answer the user's question.
 If you don't know the answer, just say that you don't know. Don't make up an answer.
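GeminiLLM above implements LangChain's minimal custom-LLM contract: _call performs the actual generation, while _llm_type and _identifying_params supply metadata. A quick standalone check of the wrapper, assuming the class is defined as in the diff, GOOGLE_API_KEY is set, and google-generativeai is installed; the sample prompt is illustrative only.

import os
import google.generativeai as genai

genai.configure(api_key=os.environ["GOOGLE_API_KEY"])

llm = GeminiLLM()        # defaults to the "gemini-2.0-flash" model name
print(llm._llm_type)     # -> "gemini"
# In current LangChain releases invoke() is the runnable entry point and routes to _call().
print(llm.invoke("Reply with a single short greeting."))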
@@ -96,17 +88,13 @@ Question: {question}
 Start the answer directly.
 """
 
-#
-
-# Create qa_chain using Gemini
+# QA Chain constructor
 def create_qa_chain():
     prompt = PromptTemplate(
-        template=CUSTOM_PROMPT_TEMPLATE,
+        template=CUSTOM_PROMPT_TEMPLATE,
         input_variables=["context", "question", "health_info"]
     )
-
     gemini_llm = GeminiLLM()
-
     return RetrievalQA.from_chain_type(
         llm=gemini_llm,
         chain_type="stuff",
@@ -115,41 +103,38 @@ def create_qa_chain():
         chain_type_kwargs={'prompt': prompt}
     )
 
-# Main QA Chain
 qa_chain = create_qa_chain()
 
-#
+# Function to handle question asking
 def ask_question(query: str, health_info: str = "No health profile provided"):
     try:
-        # Prepare inputs for the QA chain
         qa_inputs = {
             'query': query,
             'health_info': health_info
         }
-
-        # Get response from QA chain
         response = qa_chain.invoke(qa_inputs)
         result = response["result"]
-
-        #
+
+        # Deduplicate output
        sentences = [s.strip() for s in result.split('.') if s.strip()]
        unique_sentences = list(OrderedDict.fromkeys(sentences))
        cleaned_result = '. '.join(unique_sentences) + '.'
-
+
        return cleaned_result, []
-
+
     except Exception as e:
         return f"Error: {str(e)}", []
 
-# Gradio Interface
+# Gradio Interface
 iface = gr.Interface(
-    fn=ask_question,
+    fn=ask_question,
     inputs=[
         gr.Textbox(label="Question", placeholder="Enter your question here..."),
         gr.Textbox(label="Health Profile", placeholder="Enter your health information (optional)...", value="No health profile provided")
-    ],
+    ],
     outputs=["text", "json"],
     title="Medical RAG Chatbot",
     description="Ask medical questions and optionally provide your health profile for personalized responses."
 )
-
+
+iface.launch(share=True)
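The deduplication step in ask_question guards against the repeated sentences that stuffed-context answers sometimes contain. A small self-contained illustration of that exact post-processing; the sample text is made up for demonstration.

from collections import OrderedDict

# Same post-processing ask_question applies to the model output.
result = "Drink plenty of water. Rest well. Drink plenty of water. Rest well."
sentences = [s.strip() for s in result.split('.') if s.strip()]
unique_sentences = list(OrderedDict.fromkeys(sentences))
cleaned_result = '. '.join(unique_sentences) + '.'
print(cleaned_result)  # -> "Drink plenty of water. Rest well."

Splitting on '.' is a coarse sentence boundary (it also breaks on abbreviations and decimals), but since the step only drops exact repeats it leaves other text untouched.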