Update app.py

app.py CHANGED
@@ -1,16 +1,16 @@
-# app.py
 import os
+import requests
+from io import BytesIO
+from PyPDF2 import PdfReader
+from sentence_transformers import SentenceTransformer
+import faiss
 import streamlit as st
-from langchain.document_loaders import PyPDFLoader
-from langchain.embeddings import HuggingFaceEmbeddings
-from langchain.vectorstores import FAISS
 from langchain.chains import RetrievalQA
+from langchain.vectorstores import FAISS
+from langchain.embeddings import HuggingFaceEmbeddings
 from langchain.llms import HuggingFacePipeline
 from transformers import pipeline
 from groq import Groq
-import requests
-from PyPDF2 import PdfReader
-import io
 
 # Set up API key for Groq API
 #GROQ_API_KEY = "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"
@@ -27,58 +27,67 @@ def get_groq_client():
 groq_client = get_groq_client()
 
 
-# Predefined PDF link
-pdf_url = "https://drive.google.com/file/d/1P9InkDWyaybb8jR_xS4f4KsxTlYip8RA/view?usp=drive_link"
 
-def extract_text_from_pdf(pdf_url):
-    """Extract text from a PDF file given its Google Drive shared link."""
-    # Extract file ID from the Google Drive link
-    file_id = pdf_url.split('/d/')[1].split('/view')[0]
-    download_url = f"https://drive.google.com/uc?export=download&id={file_id}"
-    response = requests.get(download_url)
-…
-        return text
-    else:
-        st.error("Failed to download PDF.")
-        return ""
 
-…
-    st.info("Storing embeddings in FAISS...")
-    faiss_index = FAISS.from_texts([extracted_text], embeddings_model)
 
-    st.info("Setting up RAG pipeline...")
-    hf_pipeline = pipeline("text-generation", model="google/flan-t5-base", tokenizer="google/flan-t5-base")
-    llm = HuggingFacePipeline(pipeline=hf_pipeline)
 
-…
+def download_pdf(url):
+    response = requests.get(url)
+    response.raise_for_status()
+    return BytesIO(response.content)
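Note: a Google Drive /view share link serves the HTML viewer page, not the raw PDF bytes, so the pdf_links below would likely hand PdfReader HTML. The removed extract_text_from_pdf rewrote the share link into a direct-download URL; a minimal sketch of keeping that step in front of download_pdf (the helper name is illustrative, not part of the commit):

    def to_direct_download_url(share_url):
        # "https://drive.google.com/file/d/<ID>/view?usp=..." -> downloadable form
        file_id = share_url.split("/d/")[1].split("/")[0]
        return f"https://drive.google.com/uc?export=download&id={file_id}"

With that in place, the download loop further down could call download_pdf(to_direct_download_url(link)) and stay otherwise unchanged.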
+
+def extract_text_from_pdf(pdf_data):
+    reader = PdfReader(pdf_data)
+    text = "\n".join(page.extract_text() for page in reader.pages if page.extract_text())
+    return text
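Minor: page.extract_text() runs twice per page here, once in the filter and once in the join. A variant that extracts each page once, same behavior otherwise:

    def extract_text_from_pdf(pdf_data):
        pages = (page.extract_text() for page in PdfReader(pdf_data).pages)
        return "\n".join(t for t in pages if t)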
+
+def preprocess_text(text):
+    return " ".join(text.split())
+
+def build_faiss_index(embeddings, texts):
+    index = faiss.IndexFlatL2(embeddings.embedding_dim)
+    text_store = FAISS(embeddings, index)
+    text_store.add_texts(texts)
+    return text_store
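As committed, build_faiss_index is unlikely to run: HuggingFaceEmbeddings exposes no embedding_dim attribute, and LangChain's FAISS wrapper is not constructed as FAISS(embeddings, index) — its initializer also expects a docstore and an index-to-docstore mapping. The removed code already used the documented FAISS.from_texts classmethod, which builds the index, docstore, and id mapping in one call; a sketch along those lines, which would also leave the raw faiss and SentenceTransformer imports unused:

    def build_faiss_index(embeddings, texts):
        # FAISS.from_texts embeds the texts and wires up index + docstore itself
        return FAISS.from_texts(texts, embeddings)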
+
+# URLs of ASD-related PDF documents
+pdf_links = [
+    "https://drive.google.com/file/d/1P9InkDWyaybb8jR_xS4f4KsxTlYip8RA/view?usp=drive_link",  # Replace X, Y, Z with actual URLs of ASD-related literature
+    "https://drive.google.com/file/d/1P9InkDWyaybb8jR_xS4f4KsxTlYip8RA/view?usp=drive_link",
+    "https://drive.google.com/file/d/1P9InkDWyaybb8jR_xS4f4KsxTlYip8RA/view?usp=drive_link"
+]
+
+st.title("ASD Diagnosis and Therapy Chatbot")
+
+st.markdown("This application assists in diagnosing types of ASD and recommends evidence-based therapies and treatments.")
+
+with st.spinner("Downloading and extracting text from PDFs..."):
+    texts = []
+    for link in pdf_links:
+        pdf_data = download_pdf(link)
+        text = extract_text_from_pdf(pdf_data)
+        cleaned_text = preprocess_text(text)
+        texts.append(cleaned_text)
+
+with st.spinner("Generating embeddings..."):
+    embeddings_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
+    text_store = build_faiss_index(embeddings_model, texts)
+
+with st.spinner("Setting up the RAG pipeline..."):
+    hf_pipeline = pipeline("text-generation", model="gpt-2")  # Replace with a model optimized for medical text, if available
+    llm = HuggingFacePipeline(pipeline=hf_pipeline)
+    qa_chain = RetrievalQA(llm=llm, retriever=text_store.as_retriever())
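Two likely runtime failures in this block: the Hugging Face Hub id for GPT-2 is gpt2 (no hyphen), and in LangChain of this vintage RetrievalQA is built with its from_chain_type classmethod rather than instantiated directly. A sketch of the corrected block, under those assumptions:

    with st.spinner("Setting up the RAG pipeline..."):
        hf_pipeline = pipeline("text-generation", model="gpt2")  # hub id has no hyphen
        llm = HuggingFacePipeline(pipeline=hf_pipeline)
        qa_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",  # simplest chain type; an assumption, not from the commit
            retriever=text_store.as_retriever(),
        )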
+
+query = st.text_input("Ask a question about ASD diagnosis, types, or therapies:")
+if query:
+    with st.spinner("Processing your query..."):
+        answer = qa_chain.run(query)
+        st.success("Answer:")
+        st.write(answer)
+
+st.markdown("---")
+st.markdown("### Example Queries:")
+st.markdown("- What type of ASD does an individual with sensory issues have?")
+st.markdown("- What therapies are recommended for social communication challenges?")
+st.markdown("- What treatments are supported by clinical guidelines for repetitive behaviors?")
+
+st.markdown("---")
+st.markdown("Powered by Streamlit, Hugging Face, and LangChain")