Muhammadbilal10101 commited on
Commit
8bcde57
·
verified ·
1 Parent(s): 9a7c673

Update autism_chatbot.py

Browse files
Files changed (1) hide show
  1. autism_chatbot.py +193 -193
autism_chatbot.py CHANGED
@@ -1,194 +1,194 @@
1
- from langchain.chains import ConversationalRetrievalChain
2
- from langchain.memory import ConversationBufferMemory
3
- from langchain.prompts import PromptTemplate
4
- from langchain_community.embeddings import HuggingFaceEmbeddings
5
- from langchain_community.vectorstores import FAISS
6
- from langchain.llms.base import LLM
7
- from groq import Groq
8
- from typing import Any, List, Optional, Dict
9
- from pydantic import Field, BaseModel
10
- import os
11
-
12
-
13
- class GroqLLM(LLM, BaseModel):
14
- groq_api_key: str = Field(..., description="Groq API Key")
15
- model_name: str = Field(default="llama-3.3-70b-versatile", description="Model name to use")
16
- client: Optional[Any] = None
17
-
18
- def __init__(self, **data):
19
- super().__init__(**data)
20
- self.client = Groq(api_key=self.groq_api_key)
21
-
22
- @property
23
- def _llm_type(self) -> str:
24
- return "groq"
25
-
26
- def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
27
- completion = self.client.chat.completions.create(
28
- messages=[{"role": "user", "content": prompt}],
29
- model=self.model_name,
30
- **kwargs
31
- )
32
- return completion.choices[0].message.content
33
-
34
- @property
35
- def _identifying_params(self) -> Dict[str, Any]:
36
- """Get the identifying parameters."""
37
- return {
38
- "model_name": self.model_name
39
- }
40
-
41
-
42
- class AutismResearchBot:
43
- def __init__(self, groq_api_key: str, index_path: str = "faiss_index"):
44
- # Initialize the Groq LLM
45
- self.llm = GroqLLM(
46
- groq_api_key=groq_api_key,
47
- model_name="llama-3.3-70b-versatile" # You can adjust the model as needed
48
- )
49
-
50
- # Load the FAISS index
51
- self.embeddings = HuggingFaceEmbeddings(
52
- model_name="pritamdeka/S-PubMedBert-MS-MARCO",
53
- model_kwargs={'device': 'cpu'}
54
- )
55
-
56
- self.db = FAISS.load_local(index_path, self.embeddings, allow_dangerous_deserialization = True)
57
-
58
- # Initialize memory
59
- self.memory = ConversationBufferMemory(
60
- memory_key="chat_history",
61
- return_messages=True,
62
- output_key = "answer"
63
- )
64
-
65
- # Create the RAG chain
66
- self.qa_chain = self._create_qa_chain()
67
-
68
- def _create_qa_chain(self):
69
- # Define the prompt template
70
- template = """You are an expert AI assistant specialized in autism research and diagnostics. You have access to a database of scientific papers, research documents, and diagnostic tools about autism. Use this knowledge to ask targeted questions, gather relevant information, and provide an accurate, evidence-based assessment of the type of autism the person may have. Finally, offer appropriate therapy recommendations.
71
-
72
- Context from scientific papers use these context details only when you will at the end provide therapies don't dicusss these midway betwenn the conversation:
73
-
74
- {context}
75
-
76
- Chat History:
77
- {chat_history}
78
-
79
- Objective:
80
-
81
- Ask a series of insightful, diagnostic questions to gather comprehensive information about the individual's or their child's behaviors, challenges, and strengths.
82
- Analyze the responses given to these questions using knowledge from the provided research context.
83
- Determine the type of autism the individual may have based on the gathered data.
84
- Offer evidence-based therapy recommendations tailored to the identified type of autism.
85
- Instructions:
86
-
87
- Introduce yourself in the initial message. Please note not to reintroduce yourself in subsequent messages within the same chat.
88
- Each question should be clear, accessible, and empathetic while maintaining scientific accuracy.
89
- Ensure responses and questions demonstrate sensitivity to the diverse experiences of individuals with autism and their families.
90
- Cite specific findings or conclusions from the research context where relevant.
91
- Acknowledge any limitations or uncertainties in the research when analyzing responses.
92
- Aim for conciseness in responses, ensuring clarity and brevity without losing essential details.
93
- Initial Introduction:
94
- β€œβ€"
95
-
96
- Hello, I am an AI assistant specialized in autism research and diagnostics. I am here to gather some information to help provide an evidence-based assessment and recommend appropriate therapies.
97
-
98
- β€œβ€"
99
-
100
- Initial Diagnostic Question:
101
- β€œβ€"
102
-
103
- To begin, can you describe some of the behaviors or challenges that prompted you to seek this assessment?
104
-
105
- β€œβ€"
106
-
107
- Subsequent Questions: (Questions should follow based on the user's answers, aiming to gather necessary details concisely)
108
-
109
- question :
110
- {question}
111
-
112
- Answer:"""
113
-
114
- PROMPT = PromptTemplate(
115
- template=template,
116
- input_variables=["context", "chat_history", "question"]
117
- )
118
-
119
- # Create the chain
120
- chain = ConversationalRetrievalChain.from_llm(
121
- llm=self.llm,
122
- chain_type="stuff",
123
- retriever=self.db.as_retriever(
124
- search_type="similarity",
125
- search_kwargs={"k": 3}
126
- ),
127
- memory=self.memory,
128
- combine_docs_chain_kwargs={
129
- "prompt": PROMPT
130
- },
131
- # verbose = True,
132
- return_source_documents=True
133
- )
134
-
135
- return chain
136
-
137
- def answer_question(self, question: str):
138
- """
139
- Process a question and return the answer along with source documents
140
- """
141
- result = self.qa_chain({"question": question})
142
-
143
- # Extract answer and sources
144
- answer = result['answer']
145
- sources = result['source_documents']
146
-
147
- # Format sources for reference
148
- source_info = []
149
- for doc in sources:
150
- source_info.append({
151
- 'content': doc.page_content[:200] + "...",
152
- 'metadata': doc.metadata
153
- })
154
-
155
- return {
156
- 'answer': answer,
157
- 'sources': source_info
158
- }
159
-
160
- # Example usage
161
- if __name__ == "__main__":
162
- groq_api_key = "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"
163
-
164
- # Initialize the bot
165
- bot = AutismResearchBot(groq_api_key=groq_api_key)
166
-
167
- # Example question
168
- # question = "What are the latest findings regarding sensory processing in autism?"
169
- # response = bot.answer_question(question)
170
- while(1):
171
- print("*"*40)
172
- print("*"*40)
173
- print("*"*40)
174
- question = input("Enter your question (or 'quit' to exit): ")
175
- if question.lower() == 'quit':
176
- break
177
- response = bot.answer_question(question)
178
- print("\nAnswer:")
179
- print(response['answer'])
180
- # print("\nSources used:")
181
- # for source in response['sources']:
182
- # print(f"\nSource metadata: {source['metadata']}")
183
- # print(f"Content preview: {source['content']}")
184
- # bot.answer_question
185
- # Print response
186
- # print("\nAnswer:")
187
- # print(response['answer'])
188
- # print("\nSources used:")
189
- # for source in response['sources']:
190
- # print(f"\nSource metadata: {source['metadata']}")
191
- # print(f"Content preview: {source['content']}")
192
-
193
-
194
 
 
1
+ from langchain.chains import ConversationalRetrievalChain
2
+ from langchain.memory import ConversationBufferMemory
3
+ from langchain.prompts import PromptTemplate
4
+ from langchain_community.embeddings import HuggingFaceEmbeddings
5
+ from langchain_community.vectorstores import FAISS
6
+ from langchain.llms.base import LLM
7
+ from groq import Groq
8
+ from typing import Any, List, Optional, Dict
9
+ from pydantic import Field, BaseModel
10
+ import os
11
+
12
+
13
class GroqLLM(LLM, BaseModel):
    """LangChain-compatible LLM wrapper around the Groq chat-completions API.

    Validated by pydantic: ``groq_api_key`` is required, ``model_name``
    defaults to Groq's Llama 3.3 70B model.
    """

    groq_api_key: str = Field(..., description="Groq API Key")
    model_name: str = Field(default="llama-3.3-70b-versatile", description="Model name to use")
    # Groq SDK client; created after pydantic validation in __init__.
    client: Optional[Any] = None

    def __init__(self, **data):
        super().__init__(**data)
        self.client = Groq(api_key=self.groq_api_key)

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for logging/serialization."""
        return "groq"

    def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
        """Send a single-turn user prompt to Groq and return the completion text.

        Bug fix: ``stop`` was previously accepted but silently ignored;
        caller-supplied stop sequences are now forwarded to the API.
        """
        completion = self.client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=self.model_name,
            stop=stop,  # was dropped before this fix
            **kwargs
        )
        return completion.choices[0].message.content

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {
            "model_name": self.model_name
        }
40
+
41
+
42
class AutismResearchBot:
    """RAG chatbot for autism assessment.

    Combines a Groq-hosted LLM with a FAISS vector store of autism research
    papers via a LangChain ConversationalRetrievalChain, keeping per-session
    chat history in a ConversationBufferMemory.
    """

    def __init__(self, groq_api_key: str, index_path: str = "index.faiss"):
        """Load the LLM, embeddings, FAISS index, memory, and QA chain.

        Args:
            groq_api_key: API key for the Groq service.
            index_path: Path to a previously saved FAISS index directory.
        """
        # Initialize the Groq LLM
        self.llm = GroqLLM(
            groq_api_key=groq_api_key,
            model_name="llama-3.3-70b-versatile"  # You can adjust the model as needed
        )

        # Embedding model must match the one used when the index was built.
        self.embeddings = HuggingFaceEmbeddings(
            model_name="pritamdeka/S-PubMedBert-MS-MARCO",
            model_kwargs={'device': 'cpu'}
        )

        # NOTE(review): allow_dangerous_deserialization unpickles arbitrary
        # data — only load indexes produced by a trusted build step.
        self.db = FAISS.load_local(index_path, self.embeddings, allow_dangerous_deserialization=True)

        # output_key="answer" is required because the chain also returns
        # source_documents and memory must know which key to store.
        self.memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
            output_key="answer"
        )

        # Create the RAG chain
        self.qa_chain = self._create_qa_chain()

    def _create_qa_chain(self):
        """Build the ConversationalRetrievalChain with the diagnostic prompt."""
        # Prompt typos fixed relative to the previous revision
        # ("dicusss" -> "discuss", "betwenn" -> "between"); the mojibake
        # delimiters were restored to the original curly quotes.
        template = """You are an expert AI assistant specialized in autism research and diagnostics. You have access to a database of scientific papers, research documents, and diagnostic tools about autism. Use this knowledge to ask targeted questions, gather relevant information, and provide an accurate, evidence-based assessment of the type of autism the person may have. Finally, offer appropriate therapy recommendations.

Context from scientific papers use these context details only when you will at the end provide therapies don't discuss these midway between the conversation:

{context}

Chat History:
{chat_history}

Objective:

Ask a series of insightful, diagnostic questions to gather comprehensive information about the individual's or their child's behaviors, challenges, and strengths.
Analyze the responses given to these questions using knowledge from the provided research context.
Determine the type of autism the individual may have based on the gathered data.
Offer evidence-based therapy recommendations tailored to the identified type of autism.
Instructions:

Introduce yourself in the initial message. Please note not to reintroduce yourself in subsequent messages within the same chat.
Each question should be clear, accessible, and empathetic while maintaining scientific accuracy.
Ensure responses and questions demonstrate sensitivity to the diverse experiences of individuals with autism and their families.
Cite specific findings or conclusions from the research context where relevant.
Acknowledge any limitations or uncertainties in the research when analyzing responses.
Aim for conciseness in responses, ensuring clarity and brevity without losing essential details.
Initial Introduction:
“”"

Hello, I am an AI assistant specialized in autism research and diagnostics. I am here to gather some information to help provide an evidence-based assessment and recommend appropriate therapies.

“”"

Initial Diagnostic Question:
“”"

To begin, can you describe some of the behaviors or challenges that prompted you to seek this assessment?

“”"

Subsequent Questions: (Questions should follow based on the user's answers, aiming to gather necessary details concisely)

question :
{question}

Answer:"""

        PROMPT = PromptTemplate(
            template=template,
            input_variables=["context", "chat_history", "question"]
        )

        # Retriever returns the 3 most similar chunks from the FAISS index.
        chain = ConversationalRetrievalChain.from_llm(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.db.as_retriever(
                search_type="similarity",
                search_kwargs={"k": 3}
            ),
            memory=self.memory,
            combine_docs_chain_kwargs={
                "prompt": PROMPT
            },
            return_source_documents=True
        )

        return chain

    def answer_question(self, question: str):
        """Process a question and return the answer along with source documents.

        Returns:
            dict with 'answer' (str) and 'sources' (list of dicts carrying a
            200-character content preview plus the document metadata).
        """
        result = self.qa_chain({"question": question})

        # Extract answer and sources
        answer = result['answer']
        sources = result['source_documents']

        # Format sources for reference
        source_info = []
        for doc in sources:
            source_info.append({
                'content': doc.page_content[:200] + "...",
                'metadata': doc.metadata
            })

        return {
            'answer': answer,
            'sources': source_info
        }
159
+
160
# Example usage
if __name__ == "__main__":
    # SECURITY FIX: the previous revision hard-coded a live Groq API key in
    # source control. That key is compromised and must be revoked; the key is
    # now read from the environment instead (raises KeyError if unset).
    groq_api_key = os.environ["GROQ_API_KEY"]

    # Initialize the bot
    bot = AutismResearchBot(groq_api_key=groq_api_key)

    # Simple interactive REPL; type 'quit' to exit.
    while True:
        print("*" * 40)
        print("*" * 40)
        print("*" * 40)
        question = input("Enter your question (or 'quit' to exit): ")
        if question.lower() == 'quit':
            break
        response = bot.answer_question(question)
        print("\nAnswer:")
        print(response['answer'])
194