# NOTE(review): the original upload carried Hugging Face Spaces status text
# ("Spaces: Sleeping") above the code — web-scrape residue, not program text.
import os
from typing import Any, Dict, List, Optional

from groq import Groq
from langchain.chains import ConversationalRetrievalChain
from langchain.llms.base import LLM
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from pydantic import BaseModel, Field
class GroqLLM(LLM, BaseModel):
    """LangChain-compatible LLM wrapper around the Groq chat-completions API.

    Satisfies the minimal custom-LLM contract: `_call` plus the
    `_llm_type` / `_identifying_params` properties.
    """

    groq_api_key: str = Field(..., description="Groq API Key")
    model_name: str = Field(default="llama-3.3-70b-versatile", description="Model name to use")
    # Groq SDK client; created once in __init__ and reused for every call.
    client: Optional[Any] = None

    def __init__(self, **data: Any) -> None:
        super().__init__(**data)
        self.client = Groq(api_key=self.groq_api_key)

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for serialization and telemetry.

        Must be a property per the `LLM` base-class contract (was a plain
        method, which breaks LangChain's introspection).
        """
        return "groq"

    def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
        """Send `prompt` as a single user message; return the completion text.

        Args:
            prompt: Fully rendered prompt string from the chain.
            stop: Optional stop sequences; forwarded to the API (the original
                accepted this argument but silently dropped it).
            **kwargs: Extra Groq chat-completion parameters (temperature, etc.).
        """
        request: Dict[str, Any] = {
            "messages": [{"role": "user", "content": prompt}],
            "model": self.model_name,
            **kwargs,
        }
        if stop is not None:
            request["stop"] = stop
        completion = self.client.chat.completions.create(**request)
        return completion.choices[0].message.content

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters (property per the `LLM` contract)."""
        return {"model_name": self.model_name}
class AutismResearchBot:
    """Conversational RAG assistant over an autism-research FAISS index.

    Wires a Groq-backed LLM, a PubMedBERT embedding model, a local FAISS
    index, and buffered chat memory into a ConversationalRetrievalChain.
    """

    def __init__(self, groq_api_key: str, index_path: str = "./"):
        """Build the bot.

        Args:
            groq_api_key: API key for the Groq service.
            index_path: Folder containing the FAISS index files. The original
                code ignored this parameter and hardcoded "./"; it is now used
                (default changed to "./" so default-call behavior is identical).
        """
        # Initialize the Groq LLM
        self.llm = GroqLLM(
            groq_api_key=groq_api_key,
            model_name="llama-3.3-70b-versatile",  # adjust the model as needed
        )

        # Embeddings must match the model the index was built with.
        self.embeddings = HuggingFaceEmbeddings(
            model_name="pritamdeka/S-PubMedBert-MS-MARCO",
            model_kwargs={"device": "cpu"},
        )
        # allow_dangerous_deserialization: FAISS indexes are pickled; only
        # load indexes you created yourself.
        self.db = FAISS.load_local(
            index_path, self.embeddings, allow_dangerous_deserialization=True
        )

        # Buffered chat memory; output_key="answer" is required because the
        # chain also returns source_documents.
        self.memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
            output_key="answer",
        )

        # Create the RAG chain
        self.qa_chain = self._create_qa_chain()

    def _create_qa_chain(self):
        """Assemble the ConversationalRetrievalChain with the system prompt."""
        # NOTE(review): mojibake delimiters (`ββ"`) in the original template
        # were repaired to plain quotes, and typos fixed; the prompt's intent
        # is unchanged.
        template = """You are an expert AI assistant specialized in autism research and diagnostics. You have access to a database of scientific papers, research documents, and diagnostic tools about autism. Use this knowledge to ask targeted questions, gather relevant information, and provide an accurate, evidence-based assessment of the type of autism the person may have. Finally, offer appropriate therapy recommendations.
Context from scientific papers (use these context details only when you provide therapies at the end; don't discuss them midway through the conversation):
{context}
Chat History:
{chat_history}
Objective:
Ask a series of insightful, diagnostic questions to gather comprehensive information about the individual's or their child's behaviors, challenges, and strengths.
Analyze the responses given to these questions using knowledge from the provided research context.
Determine the type of autism the individual may have based on the gathered data.
Offer evidence-based therapy recommendations tailored to the identified type of autism.
Instructions:
Introduce yourself in the initial message. Please note not to reintroduce yourself in subsequent messages within the same chat.
Each question should be clear, accessible, and empathetic while maintaining scientific accuracy.
Ensure responses and questions demonstrate sensitivity to the diverse experiences of individuals with autism and their families.
Cite specific findings or conclusions from the research context where relevant.
Acknowledge any limitations or uncertainties in the research when analyzing responses.
Aim for conciseness in responses, ensuring clarity and brevity without losing essential details.
Initial Introduction:
"Hello, I am an AI assistant specialized in autism research and diagnostics. I am here to gather some information to help provide an evidence-based assessment and recommend appropriate therapies."
Initial Diagnostic Question:
"To begin, can you describe some of the behaviors or challenges that prompted you to seek this assessment?"
Subsequent Questions: (Questions should follow based on the user's answers, aiming to gather necessary details concisely)
question :
{question}
Answer:"""

        PROMPT = PromptTemplate(
            template=template,
            input_variables=["context", "chat_history", "question"],
        )

        # Retrieve the top-3 most similar chunks and stuff them into {context}.
        chain = ConversationalRetrievalChain.from_llm(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.db.as_retriever(
                search_type="similarity",
                search_kwargs={"k": 3},
            ),
            memory=self.memory,
            combine_docs_chain_kwargs={"prompt": PROMPT},
            return_source_documents=True,
        )
        return chain

    def answer_question(self, question: str):
        """Process a question and return the answer along with source documents.

        Returns:
            dict with 'answer' (str) and 'sources' (list of dicts holding a
            200-char content preview plus document metadata).
        """
        # .invoke replaces the deprecated chain.__call__ style.
        result = self.qa_chain.invoke({"question": question})

        answer = result["answer"]
        sources = result["source_documents"]

        # Format sources for reference: short preview + metadata per document.
        source_info = [
            {
                "content": doc.page_content[:200] + "...",
                "metadata": doc.metadata,
            }
            for doc in sources
        ]

        return {
            "answer": answer,
            "sources": source_info,
        }