from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.language_models.llms import LLM
from groq import Groq
from typing import Any, List, Optional, Dict
from pydantic import Field
import os


class GroqLLM(LLM):
    """Minimal LangChain LLM wrapper around the Groq chat completions API.

    LLM already subclasses pydantic's BaseModel, so it is not listed as a
    base a second time.
    """

    groq_api_key: str = Field(..., description="Groq API key")
    model_name: str = Field(default="llama-3.3-70b-versatile", description="Model name to use")
    client: Optional[Any] = None

    def __init__(self, **data):
        super().__init__(**data)
        self.client = Groq(api_key=self.groq_api_key)
    
    @property
    def _llm_type(self) -> str:
        return "groq"

    def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
        # Forward any stop sequences to the API instead of silently dropping them.
        completion = self.client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=self.model_name,
            stop=stop,
            **kwargs
        )
        return completion.choices[0].message.content
    
    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {
            "model_name": self.model_name
        }
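
# A quick standalone sanity check of the wrapper (hypothetical usage; assumes
# a valid GROQ_API_KEY environment variable):
#
#     llm = GroqLLM(groq_api_key=os.environ["GROQ_API_KEY"])
#     print(llm.invoke("Say hello in one sentence."))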


class AutismResearchBot:
    def __init__(self, groq_api_key: str, index_path: str = "faiss_index"):
        # Initialize the Groq LLM
        self.llm = GroqLLM(
            groq_api_key=groq_api_key,
            model_name="llama-3.3-70b-versatile"  # You can adjust the model as needed
        )
        
        # Load the FAISS index; the embedding model must match the one used
        # when the index was built.
        self.embeddings = HuggingFaceEmbeddings(
            model_name="pritamdeka/S-PubMedBert-MS-MARCO",
            model_kwargs={'device': 'cpu'}
        )

        # allow_dangerous_deserialization unpickles the stored index; only
        # enable it for index files you created and trust.
        self.db = FAISS.load_local(index_path, self.embeddings, allow_dangerous_deserialization=True)
        
        # Initialize memory
        self.memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
            output_key="answer"
        )
        
        # Create the RAG chain
        self.qa_chain = self._create_qa_chain()
    
    def _create_qa_chain(self):
        # Define the prompt template
        template = """You are an expert AI assistant specialized in autism research and diagnostics. You have access to a database of scientific papers, research documents, and diagnostic tools about autism. Use this knowledge to ask targeted questions, gather relevant information, and provide an accurate, evidence-based assessment of the type of autism the person may have. Finally, offer appropriate therapy recommendations.

Context from scientific papers (use these details only for the therapy recommendations at the end; do not discuss them midway through the conversation):

{context}

Chat History:

{chat_history}

Objective:

Ask a series of insightful, diagnostic questions to gather comprehensive information about the individual's or their child's behaviors, challenges, and strengths.

Analyze the responses to these questions using knowledge from the provided research context.

Determine the type of autism the individual may have based on the gathered data.

Offer evidence-based therapy recommendations tailored to the identified type of autism.

Instructions:

Introduce yourself in the initial message; do not reintroduce yourself in subsequent messages within the same chat.

Each question should be clear, accessible, and empathetic while maintaining scientific accuracy.

Ensure responses and questions demonstrate sensitivity to the diverse experiences of individuals with autism and their families.

Cite specific findings or conclusions from the research context where relevant.

Acknowledge any limitations or uncertainties in the research when analyzing responses.

Aim for conciseness in responses, ensuring clarity and brevity without losing essential details.

Initial Introduction:

---

Hello, I am an AI assistant specialized in autism research and diagnostics. I am here to gather some information to help provide an evidence-based assessment and recommend appropriate therapies.

---

Initial Diagnostic Question:

---

To begin, can you describe some of the behaviors or challenges that prompted you to seek this assessment?

---

Subsequent Questions: (questions should follow from the user's answers, aiming to gather the necessary details concisely)

Question:

{question}

Answer:"""

        PROMPT = PromptTemplate(
            template=template,
            input_variables=["context", "chat_history", "question"]
        )
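
        # Note: ConversationalRetrievalChain first condenses the follow-up
        # question and chat history into a standalone question (using its
        # default condense prompt), then retrieves documents for it; PROMPT
        # above only formats the final answer step.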
        
        # Create the chain: retrieve the top 3 most similar chunks and stuff
        # them into the prompt as {context}.
        chain = ConversationalRetrievalChain.from_llm(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.db.as_retriever(
                search_type="similarity",
                search_kwargs={"k": 3}
            ),
            memory=self.memory,
            combine_docs_chain_kwargs={
                "prompt": PROMPT
            },
            return_source_documents=True
        )
        
        return chain
    
    def answer_question(self, question: str):
        """

        Process a question and return the answer along with source documents

        """
        result = self.qa_chain({"question": question})
        
        # Extract answer and sources
        answer = result['answer']
        sources = result['source_documents']
        
        # Format sources for reference
        source_info = []
        for doc in sources:
            source_info.append({
                'content': doc.page_content[:200] + "...",
                'metadata': doc.metadata
            })
        
        return {
            'answer': answer,
            'sources': source_info
        }
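
# AutismResearchBot.answer_question returns a plain dict, e.g. (illustrative
# values only):
#
#     {
#         'answer': 'Hello, I am an AI assistant specialized in autism research...',
#         'sources': [{'content': 'Sensory processing differences...',
#                      'metadata': {'source': 'paper.pdf', 'page': 3}}],
#     }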

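# Before running, a FAISS index must exist at ./faiss_index. A minimal sketch
# of how such an index might be built (hypothetical loader and paths; the
# embedding model must match the one AutismResearchBot loads):
#
#     from langchain_community.document_loaders import PyPDFLoader
#     from langchain.text_splitter import RecursiveCharacterTextSplitter
#
#     docs = PyPDFLoader("papers/example_paper.pdf").load()
#     chunks = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100).split_documents(docs)
#     embeddings = HuggingFaceEmbeddings(model_name="pritamdeka/S-PubMedBert-MS-MARCO")
#     FAISS.from_documents(chunks, embeddings).save_local("faiss_index")
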
# Example usage
if __name__ == "__main__":
    # Never commit API keys to source control; read the key from the
    # environment instead.
    groq_api_key = os.environ["GROQ_API_KEY"]

    # Initialize the bot
    bot = AutismResearchBot(groq_api_key=groq_api_key)

    # Simple interactive loop
    while True:
        print("*" * 40)
        question = input("Enter your question (or 'quit' to exit): ")
        if question.lower() == 'quit':
            break
        response = bot.answer_question(question)
        print("\nAnswer:")
        print(response['answer'])
        # Optionally inspect the retrieved sources:
        # print("\nSources used:")
        # for source in response['sources']:
        #     print(f"\nSource metadata: {source['metadata']}")
        #     print(f"Content preview: {source['content']}")