EngrHamidUllah commited on
Commit
7388333
·
verified ·
1 Parent(s): f0b72b2

Upload 5 files

Browse files
Files changed (6) hide show
  1. .gitattributes +1 -0
  2. app.py +155 -0
  3. autism_chatbot.py +158 -0
  4. index.faiss +3 -0
  5. index.pkl +3 -0
  6. requirements.txt +6 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ index.faiss filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import time
3
+ from autism_chatbot import *
4
+
5
class StreamHandler:
    """Accumulates streamed text chunks and re-renders the running total
    into a Streamlit placeholder on every append."""

    def __init__(self, placeholder):
        # Placeholder is any object exposing .markdown(str) — in practice st.empty().
        self.text_container = placeholder
        self.text = ""

    def append_text(self, text: str) -> None:
        """Append one chunk and redraw the full accumulated text."""
        self.text = self.text + text
        self.text_container.markdown(self.text)
13
+
14
class StreamingGroqLLM(GroqLLM):
    """GroqLLM variant that streams completion chunks into a StreamHandler
    as they arrive, so the UI can render the answer incrementally."""

    stream_handler: Any = Field(None, description="Stream handler for real-time output")

    def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
        """Run a streaming chat completion against Groq.

        Each non-empty delta chunk is forwarded to ``self.stream_handler``
        (when one is set) and accumulated; the full concatenated response
        text is returned.
        """
        completion = self.client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=self.model_name,
            stream=True,
            **kwargs
        )

        # Fix: the original kept a second list (collected_chunks) with identical
        # contents that was never read — a single accumulator suffices.
        collected_messages = []

        for chunk in completion:
            chunk_message = chunk.choices[0].delta.content
            if chunk_message is not None:
                collected_messages.append(chunk_message)
                if self.stream_handler:
                    self.stream_handler.append_text(chunk_message)
                # Small delay so the streaming effect is visible in the UI.
                time.sleep(0.05)

        return ''.join(collected_messages)
38
+
39
class StreamingAutismResearchBot(AutismResearchBot):
    """AutismResearchBot wired to a StreamingGroqLLM so responses stream
    into the Streamlit UI chunk by chunk."""

    def __init__(self, groq_api_key: str, stream_handler: StreamHandler, index_path: str = "index.faiss"):
        # NOTE(review): index_path is accepted but never used — the FAISS index is
        # loaded from the current directory below. Confirm whether callers rely on it.
        # Streaming LLM pushes each completion chunk into stream_handler.
        self.llm = StreamingGroqLLM(
            groq_api_key=groq_api_key,
            model_name="llama-3.3-70b-versatile",
            stream_handler=stream_handler
        )

        # Embedding model; must match the one used when the index was built.
        self.embeddings = HuggingFaceEmbeddings(
            model_name="pritamdeka/S-PubMedBert-MS-MARCO",
            model_kwargs={'device': 'cpu'}
        )
        # Loads index.faiss / index.pkl from the working directory.
        # allow_dangerous_deserialization is needed because index.pkl is a pickle —
        # acceptable here only because the index files ship with this repo.
        self.db = FAISS.load_local("./", self.embeddings, allow_dangerous_deserialization=True)

        # Conversation memory; output_key="answer" selects which chain output to store.
        self.memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
            output_key="answer"
        )

        # RAG chain construction is inherited from AutismResearchBot.
        self.qa_chain = self._create_qa_chain()
60
+
61
def main():
    """Streamlit entry point: renders the chat UI and drives the streaming bot."""
    # Page configuration
    st.set_page_config(
        page_title="Autism Research Assistant",
        page_icon="🧩",
        layout="wide"
    )

    # Add custom CSS (layout width + chat bubble / timestamp styling)
    st.markdown("""
    <style>
    .stApp {
        max-width: 1200px;
        margin: 0 auto;
    }
    .stMarkdown {
        font-size: 16px;
    }
    .chat-message {
        padding: 1rem;
        border-radius: 0.5rem;
        margin-bottom: 1rem;
    }
    .timestamp {
        font-size: 0.8em;
        color: #666;
    }
    </style>
    """, unsafe_allow_html=True)

    # Header
    st.title("🧩 Autism Research Assistant")
    st.markdown("""
    Welcome to your AI-powered autism research assistant. I'm here to provide evidence-based
    assessments and therapy recommendations based on scientific research.
    """)

    # Seed the conversation with a greeting on first load
    if 'messages' not in st.session_state:
        st.session_state.messages = [
            {"role": "assistant", "content": "Hello! I'm your autism research assistant. How can I help you today?"}
        ]

    # Bot is created lazily on the first user message — it needs a live
    # placeholder from the current script run to stream into.
    if 'bot' not in st.session_state:
        st.session_state.stream_container = None
        st.session_state.bot = None

    # Replay the chat history on every rerun
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(f"{message['content']}")
            st.caption(f"{time.strftime('%I:%M %p')}")

    # Chat input
    if prompt := st.chat_input("Type your message here..."):
        # Display user message
        with st.chat_message("user"):
            st.write(prompt)
            st.caption(f"{time.strftime('%I:%M %p')}")

        # Add to session state
        st.session_state.messages.append({"role": "user", "content": prompt})

        # Create a new chat message container for the assistant's response
        assistant_message = st.chat_message("assistant")
        with assistant_message:
            # Placeholder the stream handler rewrites as chunks arrive
            stream_placeholder = st.empty()

            # Initialize the bot with the new stream handler if not already initialized
            if st.session_state.bot is None:
                stream_handler = StreamHandler(stream_placeholder)
                st.session_state.bot = StreamingAutismResearchBot(
                    groq_api_key= os.environ.get("GROQ_API_KEY"),
                    stream_handler=stream_handler,
                )
            else:
                # Re-point the existing handler at this run's placeholder and
                # clear its accumulated text from the previous turn.
                st.session_state.bot.llm.stream_handler.text = ""
                st.session_state.bot.llm.stream_handler.text_container = stream_placeholder

            # Generate response (streams into stream_placeholder as a side effect)
            response = st.session_state.bot.answer_question(prompt)

            # Clear the streaming placeholder and display the final message
            stream_placeholder.empty()
            st.write(response['answer'])
            st.caption(f"{time.strftime('%I:%M %p')}")

        # Add bot response to session state
        st.session_state.messages.append({"role": "assistant", "content": response['answer']})


if __name__ == "__main__":
    main()
autism_chatbot.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.chains import ConversationalRetrievalChain
2
+ from langchain.memory import ConversationBufferMemory
3
+ from langchain.prompts import PromptTemplate
4
+ from langchain_community.embeddings import HuggingFaceEmbeddings
5
+ from langchain_community.vectorstores import FAISS
6
+ from langchain.llms.base import LLM
7
+ from groq import Groq
8
+ from typing import Any, List, Optional, Dict
9
+ from pydantic import Field, BaseModel
10
+ import os
11
+
12
+
13
class GroqLLM(LLM, BaseModel):
    """LangChain LLM wrapper around the Groq chat-completions client."""

    groq_api_key: str = Field(..., description="Groq API Key")
    model_name: str = Field(default="llama-3.3-70b-versatile", description="Model name to use")
    client: Optional[Any] = None

    def __init__(self, **fields):
        super().__init__(**fields)
        # The SDK client is built after pydantic validation so it can use the key.
        self.client = Groq(api_key=self.groq_api_key)

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for this LLM implementation."""
        return "groq"

    def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
        """Send the prompt as a single user message and return the reply text."""
        response = self.client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=self.model_name,
            **kwargs
        )
        first_choice = response.choices[0]
        return first_choice.message.content

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Parameters that uniquely identify this LLM configuration."""
        return {"model_name": self.model_name}
40
+
41
+
42
class AutismResearchBot:
    """RAG chatbot over an autism-research FAISS index, backed by a Groq-hosted LLM
    with conversational memory."""

    def __init__(self, groq_api_key: str, index_path: str = "index.faiss"):
        # NOTE(review): index_path is accepted but unused — FAISS.load_local below
        # reads from "./". Confirm whether the parameter should be honored.
        # Initialize the Groq LLM
        self.llm = GroqLLM(
            groq_api_key=groq_api_key,
            model_name="llama-3.3-70b-versatile" # You can adjust the model as needed
        )

        # Load the FAISS index; the embedding model must match the one used
        # when the index was built.
        self.embeddings = HuggingFaceEmbeddings(
            model_name="pritamdeka/S-PubMedBert-MS-MARCO",
            model_kwargs={'device': 'cpu'}
        )

        # index.pkl is a pickle, hence allow_dangerous_deserialization — only
        # acceptable because the index files ship with this repo.
        self.db = FAISS.load_local("./", self.embeddings, allow_dangerous_deserialization = True)

        # Initialize memory; output_key selects which chain output gets stored.
        self.memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
            output_key = "answer"
        )

        # Create the RAG chain
        self.qa_chain = self._create_qa_chain()

    def _create_qa_chain(self):
        """Build the ConversationalRetrievalChain: diagnostic prompt + top-k retriever + memory."""
        # Define the prompt template (verbatim — this text is sent to the model)
        template = """You are an expert AI assistant specialized in autism research and diagnostics. You have access to a database of scientific papers, research documents, and diagnostic tools about autism. Use this knowledge to ask targeted questions, gather relevant information, and provide an accurate, evidence-based assessment of the type of autism the person may have. Finally, offer appropriate therapy recommendations.

Context from scientific papers use these context details only when you will at the end provide therapies don't dicusss these midway betwenn the conversation:

{context}

Chat History:
{chat_history}

Objective:

Ask a series of insightful, diagnostic questions to gather comprehensive information about the individual's or their child's behaviors, challenges, and strengths.
Analyze the responses given to these questions using knowledge from the provided research context.
Determine the type of autism the individual may have based on the gathered data.
Offer evidence-based therapy recommendations tailored to the identified type of autism.
Instructions:

Introduce yourself in the initial message. Please note not to reintroduce yourself in subsequent messages within the same chat.
Each question should be clear, accessible, and empathetic while maintaining scientific accuracy.
Ensure responses and questions demonstrate sensitivity to the diverse experiences of individuals with autism and their families.
Cite specific findings or conclusions from the research context where relevant.
Acknowledge any limitations or uncertainties in the research when analyzing responses.
Aim for conciseness in responses, ensuring clarity and brevity without losing essential details.
Initial Introduction:
β€œβ€"

Hello, I am an AI assistant specialized in autism research and diagnostics. I am here to gather some information to help provide an evidence-based assessment and recommend appropriate therapies.

β€œβ€"

Initial Diagnostic Question:
β€œβ€"

To begin, can you describe some of the behaviors or challenges that prompted you to seek this assessment?

β€œβ€"

Subsequent Questions: (Questions should follow based on the user's answers, aiming to gather necessary details concisely)

question :
{question}

Answer:"""

        PROMPT = PromptTemplate(
            template=template,
            input_variables=["context", "chat_history", "question"]
        )

        # Create the chain: "stuff" packs all retrieved docs into one prompt.
        chain = ConversationalRetrievalChain.from_llm(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.db.as_retriever(
                search_type="similarity",
                search_kwargs={"k": 3}
            ),
            memory=self.memory,
            combine_docs_chain_kwargs={
                "prompt": PROMPT
            },
            # verbose = True,
            return_source_documents=True
        )

        return chain

    def answer_question(self, question: str):
        """
        Process a question and return the answer along with source documents
        """
        result = self.qa_chain({"question": question})

        # Extract answer and sources
        answer = result['answer']
        sources = result['source_documents']

        # Format sources for reference: 200-char preview plus metadata per doc
        source_info = []
        for doc in sources:
            source_info.append({
                'content': doc.page_content[:200] + "...",
                'metadata': doc.metadata
            })

        return {
            'answer': answer,
            'sources': source_info
        }
index.faiss ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d2f5db5d800828252b7fee1fb83824c73aa459443ccf0e66e148a928d103c565
3
+ size 7397421
index.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b33f5ef3db502ad9564eb1d66b04e574cbc623bf8a8d6a1f3963db39f3398d15
3
+ size 5847437
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ langchain
2
+ langchain-community
3
+ groq
4
+ sentence-transformers
5
+ streamlit
6
+ faiss-cpu