engrphoenix committed · Commit c3a4d93 · verified · 1 Parent(s): 5eedd0a

Update app.py

Files changed (1):
  1. app.py (+60, -51)
app.py CHANGED
@@ -1,16 +1,16 @@
-# app.py
 import os
+import requests
+from io import BytesIO
+from PyPDF2 import PdfReader
+from sentence_transformers import SentenceTransformer
+import faiss
 import streamlit as st
-from langchain.document_loaders import PyPDFLoader
-from langchain.embeddings import HuggingFaceEmbeddings
-from langchain.vectorstores import FAISS
 from langchain.chains import RetrievalQA
+from langchain.vectorstores import FAISS
+from langchain.embeddings import HuggingFaceEmbeddings
 from langchain.llms import HuggingFacePipeline
 from transformers import pipeline
 from groq import Groq
-import requests
-from PyPDF2 import PdfReader
-import io

 # Set up API key for Groq API
 #GROQ_API_KEY = "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"
@@ -27,58 +27,67 @@ def get_groq_client():
 groq_client = get_groq_client()


-# Predefined PDF link
-pdf_url = "https://drive.google.com/file/d/1P9InkDWyaybb8jR_xS4f4KsxTlYip8RA/view?usp=drive_link"

-def extract_text_from_pdf(pdf_url):
-    """Extract text from a PDF file given its Google Drive shared link."""
-    # Extract file ID from the Google Drive link
-    file_id = pdf_url.split('/d/')[1].split('/view')[0]
-    download_url = f"https://drive.google.com/uc?export=download&id={file_id}"
-    response = requests.get(download_url)

-    if response.status_code == 200:
-        pdf_content = io.BytesIO(response.content)
-        reader = PdfReader(pdf_content)
-        text = "\n".join([page.extract_text() for page in reader.pages])
-        return text
-    else:
-        st.error("Failed to download PDF.")
-        return ""
+def download_pdf(url):
+    response = requests.get(url)
+    response.raise_for_status()
+    return BytesIO(response.content)

-# Streamlit Interface
-st.title("ASD Diagnosis Retrieval-Augmented Generation App")
+def extract_text_from_pdf(pdf_data):
+    reader = PdfReader(pdf_data)
+    text = "\n".join(page.extract_text() for page in reader.pages if page.extract_text())
+    return text

-st.info("Processing predefined PDF...")
-extracted_text = extract_text_from_pdf(pdf_url)
+def preprocess_text(text):
+    return " ".join(text.split())

-if extracted_text:
-    st.success("Text extraction complete.")
+def build_faiss_index(embeddings, texts):
+    index = faiss.IndexFlatL2(embeddings.embedding_dim)
+    text_store = FAISS(embeddings, index)
+    text_store.add_texts(texts)
+    return text_store

-    # Preprocess text for embeddings
-    st.info("Generating embeddings...")
-    embeddings_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
-    embeddings = embeddings_model.embed_documents([extracted_text])
+# URLs of ASD-related PDF documents
+pdf_links = [
+    "https://drive.google.com/file/d/1P9InkDWyaybb8jR_xS4f4KsxTlYip8RA/view?usp=drive_link",  # Replace X, Y, Z with actual URLs of ASD-related literature
+    "https://drive.google.com/file/d/1P9InkDWyaybb8jR_xS4f4KsxTlYip8RA/view?usp=drive_link",
+    "https://drive.google.com/file/d/1P9InkDWyaybb8jR_xS4f4KsxTlYip8RA/view?usp=drive_link"
+]

-    # Store embeddings in FAISS
-    st.info("Storing embeddings in FAISS...")
-    faiss_index = FAISS.from_texts([extracted_text], embeddings_model)
+st.title("ASD Diagnosis and Therapy Chatbot")

-    # Set up Hugging Face LLM pipeline
-    st.info("Setting up RAG pipeline...")
-    hf_pipeline = pipeline("text-generation", model="google/flan-t5-base", tokenizer="google/flan-t5-base")
-    llm = HuggingFacePipeline(pipeline=hf_pipeline)
+st.markdown("This application assists in diagnosing types of ASD and recommends evidence-based therapies and treatments.")

-    retriever = faiss_index.as_retriever()
-    qa_chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever)
+with st.spinner("Downloading and extracting text from PDFs..."):
+    texts = []
+    for link in pdf_links:
+        pdf_data = download_pdf(link)
+        text = extract_text_from_pdf(pdf_data)
+        cleaned_text = preprocess_text(text)
+        texts.append(cleaned_text)

-    # Query interface
-    st.success("RAG pipeline ready.")
-    user_query = st.text_input("Enter your query about ASD:")
+with st.spinner("Generating embeddings..."):
+    embeddings_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
+    text_store = build_faiss_index(embeddings_model, texts)

-    if user_query:
-        st.info("Fetching response...")
-        response = qa_chain.run(user_query)
-        st.success(response)
-else:
-    st.error("No text extracted from the PDF.")
+with st.spinner("Setting up the RAG pipeline..."):
+    hf_pipeline = pipeline("text-generation", model="gpt-2")  # Replace with a model optimized for medical text, if available
+    llm = HuggingFacePipeline(pipeline=hf_pipeline)
+    qa_chain = RetrievalQA(llm=llm, retriever=text_store.as_retriever())
+
+query = st.text_input("Ask a question about ASD diagnosis, types, or therapies:")
+if query:
+    with st.spinner("Processing your query..."):
+        answer = qa_chain.run(query)
+        st.success("Answer:")
+        st.write(answer)
+
+st.markdown("---")
+st.markdown("### Example Queries:")
+st.markdown("- What type of ASD does an individual with sensory issues have?")
+st.markdown("- What therapies are recommended for social communication challenges?")
+st.markdown("- What treatments are supported by clinical guidelines for repetitive behaviors?")
+
+st.markdown("---")
+st.markdown("Powered by Streamlit, Hugging Face, and LangChain")