Update rag.py
rag.py CHANGED
```diff
@@ -78,36 +78,39 @@ def generate_answer(query: str) -> str:
     # Load FAISS indices
     logger.info("Loading FAISS indices...")
     data_vec = load_faiss_index(
-        str(vectors_dir / "
-        "sentence-transformers/all-MiniLM-L12-v2"
-    )
-    med_vec = load_faiss_index(
-        str(vectors_dir / "med_data_vec"),
+        str(vectors_dir / "faiss_v4"),
         "sentence-transformers/all-MiniLM-L12-v2"
     )
 
     # Create the LLM instance
     llm = ChatOpenAI(
-        model="gpt-
+        model="gpt-4o-mini",
         temperature=0,
         openai_api_key=OPENAI_API_KEY
     )
 
     # Define the prompt template
     template = """You are a helpful medical information assistant. Use the following pieces of context to answer the medical question at the end.
-    [13 removed template lines; their content was not recoverable from the page extraction]
+
+    Important notes:
+    - Base your answer strictly on the provided context in a clear and reader-friendly way, using paragraphs or bullet points as needed.
+    - If you don't know the answer, say so honestly.
+    - Include relevant disclaimers recommending consulting a healthcare professional for personalized advice.
+    - When suggesting treatment options or medications (if applicable to the question), include:
+        - Drug name
+        - Drug class
+        - Dosage
+        - Frequency and duration
+        - Potential adverse effects
+        - Associated risks and additional recommendations
+    - Highlight if the information is general knowledge or requires professional medical advice.
+    - Prompt for additional details about the disease or patient characteristics if necessary to provide a thorough answer, and encourage users to ask clarifying questions.
+
+    Context: {context}
+
+    Question: {question}
+
+    Medical Information Assistant:"""
 
     QA_CHAIN_PROMPT = PromptTemplate(
         input_variables=["context", "question"],
@@ -117,15 +120,11 @@ def generate_answer(query: str) -> str:
     # Initialize and combine retrievers
     logger.info("Setting up retrieval chain...")
     data_retriever = data_vec.as_retriever()
-    med_retriever = med_vec.as_retriever()
-    combined_retriever = MergerRetriever(
-        retrievers=[data_retriever, med_retriever]
-    )
 
     # Initialize the RetrievalQA chain
     qa_chain = RetrievalQA.from_chain_type(
         llm=llm,
-        retriever=
+        retriever=data_retriever,
         return_source_documents=True,
         chain_type_kwargs={"prompt": QA_CHAIN_PROMPT}
     )
@@ -138,12 +137,7 @@ def generate_answer(query: str) -> str:
     # Extracting the relevant documents from the result
     extracted_docs = result.get("source_documents", [])
     logger.info(f"Extracted documents: {extracted_docs}")  # Log the extracted documents
-
-    # New organized printing of extracted documents
-    print("\nExtracted Documents:")
-    for doc in extracted_docs:
-        print(f"Source: {doc.metadata['source']}, Row: {doc.metadata['row']}")
-        print(f"Content: {doc.page_content}\n")
+
 
     return result["result"]
```
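Both versions call a `load_faiss_index` helper that is defined elsewhere in rag.py and not shown in this diff. For reference, a minimal sketch of what such a helper typically looks like with LangChain's FAISS wrapper; the name and two-argument signature come from the call sites above, while the body, the imports, and the `allow_dangerous_deserialization` flag are assumptions about the unseen implementation:

```python
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS


def load_faiss_index(index_dir: str, embedding_model: str) -> FAISS:
    """Load a persisted FAISS index from disk.

    Sketch only: the real helper in rag.py may differ.
    """
    embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
    # Recent langchain_community versions require this flag to load the
    # pickled docstore that sits alongside the raw FAISS index files.
    return FAISS.load_local(
        index_dir,
        embeddings,
        allow_dangerous_deserialization=True,
    )
```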
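The rewritten template keeps `{context}` and `{question}` as its only placeholders, matching the `input_variables` passed to `PromptTemplate` just below it. A self-contained sanity check of that wiring (the template is abbreviated here for brevity):

```python
from langchain.prompts import PromptTemplate

template = """You are a helpful medical information assistant. Use the following pieces of context to answer the medical question at the end.

Context: {context}

Question: {question}

Medical Information Assistant:"""

QA_CHAIN_PROMPT = PromptTemplate(
    input_variables=["context", "question"],
    template=template,
)

# Render once to confirm both variables substitute cleanly.
print(QA_CHAIN_PROMPT.format(context="(retrieved passages)", question="(user question)"))
```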
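The second hunk removes the `MergerRetriever` that used to interleave hits from the two FAISS indices, so the chain now queries `data_vec` alone. Reconstructed from the deleted lines, the before-and-after retriever setup looks like this (the wrapper functions are added purely for illustration):

```python
from langchain.retrievers import MergerRetriever
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStore


def build_retriever_before(data_vec: VectorStore, med_vec: VectorStore) -> BaseRetriever:
    # Pre-commit behavior: interleave results from both indices.
    return MergerRetriever(
        retrievers=[data_vec.as_retriever(), med_vec.as_retriever()]
    )


def build_retriever_after(data_vec: VectorStore) -> BaseRetriever:
    # Post-commit behavior: only the data_vec index is consulted.
    return data_vec.as_retriever()
```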
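Finally, the tail of `generate_answer` returns `result["result"]` now that the commit drops the debug printing of source documents. With `RetrievalQA` and `return_source_documents=True`, the chain output is a dict carrying both the answer and the retrieved documents; a usage sketch, assuming the `qa_chain` assembled in the diff above and a made-up query:

```python
# Assumes qa_chain from the diff above; the query string is hypothetical.
result = qa_chain.invoke({"query": "What are common treatments for migraine?"})

answer = result["result"]                     # the model's answer text
sources = result.get("source_documents", [])  # Documents used as context

for doc in sources:
    # Each source is a LangChain Document with .page_content / .metadata.
    print(doc.metadata.get("source"), "->", doc.page_content[:80])
```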