import os
import requests
from io import BytesIO
from PyPDF2 import PdfReader
import streamlit as st
from langchain.chains import RetrievalQA
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFacePipeline
from transformers import pipeline
from groq import Groq

# The Groq API key is read from the environment (e.g. a Hugging Face Space
# secret) rather than being hard-coded in the source.
def get_groq_client():
    # The secret must be stored under this exact (case-sensitive) variable name.
    api_key = os.getenv("groq_api_key")
    if not api_key:
        raise ValueError("Groq API key not found in environment variables.")
    return Groq(api_key=api_key)

groq_client = get_groq_client()
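
# Note: groq_client is initialized here, but the retrieval chain below answers
# queries with a local Hugging Face pipeline instead. If you would rather route
# answers through Groq's hosted models, a minimal sketch (the model name is an
# assumption; use any model available on your Groq account) would look like:
#
#   completion = groq_client.chat.completions.create(
#       model="llama3-8b-8192",
#       messages=[{"role": "user", "content": query}],
#   )
#   answer = completion.choices[0].message.content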

def download_pdf(url):
    # Fetch the PDF over HTTP and return it as an in-memory file object.
    response = requests.get(url)
    response.raise_for_status()
    return BytesIO(response.content)

def extract_text_from_pdf(pdf_data):
    # Concatenate the text of all pages, skipping pages with no extractable text.
    reader = PdfReader(pdf_data)
    text = "\n".join(page.extract_text() for page in reader.pages if page.extract_text())
    return text

def preprocess_text(text):
    # Collapse newlines and repeated whitespace into single spaces.
    return " ".join(text.split())

def build_faiss_index(embeddings, texts):
    # Let LangChain create and populate the underlying FAISS index from the texts.
    return FAISS.from_texts(texts, embeddings)

# URLs of ASD-related PDF documents.
# Replace these placeholders with actual URLs of ASD-related literature; use
# direct-download links, since Google Drive "view" links return an HTML page
# rather than the PDF bytes.
pdf_links = [
    "https://drive.google.com/file/d/1P9InkDWyaybb8jR_xS4f4KsxTlYip8RA/view?usp=drive_link",
    "https://drive.google.com/file/d/1P9InkDWyaybb8jR_xS4f4KsxTlYip8RA/view?usp=drive_link",
    "https://drive.google.com/file/d/1P9InkDWyaybb8jR_xS4f4KsxTlYip8RA/view?usp=drive_link",
]

st.title("ASD Diagnosis and Therapy Chatbot")
st.markdown("This application assists in diagnosing types of ASD and recommends evidence-based therapies and treatments.")
with st.spinner("Downloading and extracting text from PDFs..."): | |
texts = [] | |
for link in pdf_links: | |
pdf_data = download_pdf(link) | |
text = extract_text_from_pdf(pdf_data) | |
cleaned_text = preprocess_text(text) | |
texts.append(cleaned_text) | |
with st.spinner("Generating embeddings..."): | |
embeddings_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2") | |
text_store = build_faiss_index(embeddings_model, texts) | |
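
# Note: Streamlit re-runs this script on every interaction, so the downloads,
# embedding, and index build above are repeated on each rerun. A common
# optimization (not part of the original code) is to wrap those steps in a
# cached helper, for example:
#
#   @st.cache_resource
#   def load_vector_store(links):
#       docs = [preprocess_text(extract_text_from_pdf(download_pdf(u))) for u in links]
#       embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
#       return build_faiss_index(embeddings, docs)
#
#   text_store = load_vector_store(tuple(pdf_links))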
with st.spinner("Setting up the RAG pipeline..."): | |
hf_pipeline = pipeline("text-generation", model="gpt-2") # Replace with a model optimized for medical text, if available | |
llm = HuggingFacePipeline(pipeline=hf_pipeline) | |
qa_chain = RetrievalQA(llm=llm, retriever=text_store.as_retriever()) | |

query = st.text_input("Ask a question about ASD diagnosis, types, or therapies:")
if query:
    with st.spinner("Processing your query..."):
        answer = qa_chain.run(query)
    st.success("Answer:")
    st.write(answer)
st.markdown("---") | |
st.markdown("### Example Queries:") | |
st.markdown("- What type of ASD does an individual with sensory issues have?") | |
st.markdown("- What therapies are recommended for social communication challenges?") | |
st.markdown("- What treatments are supported by clinical guidelines for repetitive behaviors?") | |
st.markdown("---") | |
st.markdown("Powered by Streamlit, Hugging Face, and LangChain") | |