# ADS / app.py
# (The lines above/below were Hugging Face Space page chrome captured by the
# scrape — "engrphoenix's picture / Update app.py / c3a4d93 verified" — and
# are not Python; kept here as a comment so the file parses.)
import os
import requests
from io import BytesIO
from PyPDF2 import PdfReader
from sentence_transformers import SentenceTransformer
import faiss
import streamlit as st
from langchain.chains import RetrievalQA
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFacePipeline
from transformers import pipeline
from groq import Groq
# Groq API key is read from the environment (see get_groq_client below).
# SECURITY: a hard-coded Groq API key was previously committed here in a
# commented-out block; it has been scrubbed. Revoke that key and never
# commit secrets to source control — load them from the environment instead.
def get_groq_client():
    """Create a Groq API client using the key from the environment.

    Reads the ``groq_api_key`` environment variable (as originally written),
    falling back to the conventional upper-case ``GROQ_API_KEY`` so both
    existing and standard deployments work.

    Returns:
        Groq: an authenticated Groq client.

    Raises:
        ValueError: if no API key is present in the environment.
    """
    api_key = os.getenv("groq_api_key") or os.getenv("GROQ_API_KEY")
    if not api_key:
        raise ValueError("Groq API key not found in environment variables.")
    return Groq(api_key=api_key)
# Module-level client created at import time: the app fails fast with a
# ValueError if the Groq key is missing from the environment.
groq_client = get_groq_client()
def download_pdf(url):
    """Download *url* and return its body as an in-memory binary stream.

    Args:
        url: HTTP(S) URL of a PDF document.

    Returns:
        BytesIO: the raw response body.

    Raises:
        requests.HTTPError: on a non-2xx response status.
        requests.Timeout: if the server does not respond within 30 seconds.
    """
    # Without an explicit timeout, requests.get can block forever on a
    # stalled connection, hanging the whole Streamlit app at startup.
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    return BytesIO(response.content)
def extract_text_from_pdf(pdf_data):
    """Extract the text of every page of a PDF, joined by newlines.

    Args:
        pdf_data: a binary file-like object containing the PDF bytes.

    Returns:
        str: concatenated page text; pages that yield no text are skipped.
    """
    reader = PdfReader(pdf_data)
    # Call extract_text() only once per page — the original invoked it twice
    # (once in the filter, once in the join), doubling the parsing work.
    page_texts = (page.extract_text() for page in reader.pages)
    return "\n".join(text for text in page_texts if text)
def preprocess_text(text):
    """Collapse every run of whitespace in *text* into a single space.

    Leading and trailing whitespace is removed as a side effect of
    str.split(), which drops empty tokens.
    """
    tokens = text.split()
    return " ".join(tokens)
def build_faiss_index(embeddings, texts):
    """Build a LangChain FAISS vector store over *texts*.

    Args:
        embeddings: a LangChain embeddings object (e.g. HuggingFaceEmbeddings).
        texts: iterable of document strings to index.

    Returns:
        FAISS: a vector store ready to be used via ``.as_retriever()``.
    """
    # The original built FAISS(embeddings, index) by hand, but
    # HuggingFaceEmbeddings exposes no `embedding_dim` attribute and the
    # LangChain FAISS wrapper also needs a docstore and an id mapping.
    # FAISS.from_texts constructs the underlying faiss index plus all of
    # that bookkeeping correctly from the embedding model itself.
    return FAISS.from_texts(list(texts), embeddings)
# URLs of ASD-related PDF documents.
# NOTE(review): all three entries are the *same* Google Drive viewer URL.
# A `/view` link returns an HTML page, not PDF bytes, so PdfReader will
# likely fail on the downloaded content — presumably these should be the
# direct-download form (https://drive.google.com/uc?export=download&id=<ID>)
# and three distinct documents. TODO: confirm and replace the placeholders.
pdf_links = [
"https://drive.google.com/file/d/1P9InkDWyaybb8jR_xS4f4KsxTlYip8RA/view?usp=drive_link", # Replace X, Y, Z with actual URLs of ASD-related literature
"https://drive.google.com/file/d/1P9InkDWyaybb8jR_xS4f4KsxTlYip8RA/view?usp=drive_link",
"https://drive.google.com/file/d/1P9InkDWyaybb8jR_xS4f4KsxTlYip8RA/view?usp=drive_link"
]
st.title("ASD Diagnosis and Therapy Chatbot")
st.markdown("This application assists in diagnosing types of ASD and recommends evidence-based therapies and treatments.")

with st.spinner("Downloading and extracting text from PDFs..."):
    # For each source document: fetch the bytes, pull out the page text,
    # and normalise the whitespace, collecting one cleaned string per PDF.
    texts = [
        preprocess_text(extract_text_from_pdf(download_pdf(link)))
        for link in pdf_links
    ]
with st.spinner("Generating embeddings..."):
    # MiniLM is a small sentence-embedding model that runs acceptably on CPU;
    # text_store is the FAISS vector store consumed by the retriever below.
    embeddings_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    text_store = build_faiss_index(embeddings_model, texts)
with st.spinner("Setting up the RAG pipeline..."):
    # The Hugging Face Hub id is "gpt2" — the original "gpt-2" is not a
    # valid model id and raises at load time. Replace with a model
    # optimized for medical text, if available.
    hf_pipeline = pipeline("text-generation", model="gpt2")
    llm = HuggingFacePipeline(pipeline=hf_pipeline)
    # RetrievalQA's constructor requires a pre-built combine_documents_chain;
    # the supported way to assemble the chain from an LLM plus a retriever
    # is the from_chain_type factory (default chain type: "stuff").
    qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=text_store.as_retriever())
# Interactive Q&A: run the retrieval chain on whatever the user types.
if query := st.text_input("Ask a question about ASD diagnosis, types, or therapies:"):
    with st.spinner("Processing your query..."):
        answer = qa_chain.run(query)
    st.success("Answer:")
    st.write(answer)

# Static footer: example prompts plus attribution.
st.markdown("---")
st.markdown("### Example Queries:")
for example in (
    "What type of ASD does an individual with sensory issues have?",
    "What therapies are recommended for social communication challenges?",
    "What treatments are supported by clinical guidelines for repetitive behaviors?",
):
    st.markdown(f"- {example}")
st.markdown("---")
st.markdown("Powered by Streamlit, Hugging Face, and LangChain")