File size: 2,672 Bytes
1b59b58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# app.py
import io
import os

import requests
import streamlit as st
from groq import Groq
from langchain.chains import RetrievalQA
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFacePipeline
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from PyPDF2 import PdfReader
from transformers import pipeline

# Groq API credentials.
# The key is read from the GROQ_API_KEY environment variable instead of being
# committed to source control.
# NOTE(review): the key that was previously hard-coded here has been exposed
# and should be revoked and rotated.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")

# Initialize the Groq API client only when a key is configured, so the app
# still starts without one (no other code in this script calls `client`).
client = Groq(api_key=GROQ_API_KEY) if GROQ_API_KEY else None

# Predefined PDF link (Google Drive share link)
pdf_url = "https://drive.google.com/file/d/1P9InkDWyaybb8jR_xS4f4KsxTlYip8RA/view?usp=drive_link"

def extract_text_from_pdf(pdf_url, timeout=30):
    """Download a PDF from a Google Drive share link and return its text.

    Args:
        pdf_url: A Drive share link of the form
            ``https://drive.google.com/file/d/<FILE_ID>/view?...``.
        timeout: Seconds to wait for the download before giving up.

    Returns:
        The text of every page joined with newlines, or ``""`` when the link
        cannot be parsed, the download fails, or the response is not a PDF.
        In each failure case an error is shown in the Streamlit UI.
    """
    # The file ID is the path segment immediately after "/d/".
    _, sep, tail = pdf_url.partition("/d/")
    file_id = tail.split("/", 1)[0].split("?", 1)[0]
    if not sep or not file_id:
        st.error("Invalid Google Drive link.")
        return ""

    download_url = f"https://drive.google.com/uc?export=download&id={file_id}"
    try:
        # Without a timeout, requests can block indefinitely on a stalled server.
        response = requests.get(download_url, timeout=timeout)
    except requests.RequestException as err:
        st.error(f"Failed to download PDF: {err}")
        return ""

    # Drive can answer with an HTML page (e.g. a confirmation or sign-in page)
    # and status 200 instead of the file, so also check for the PDF header,
    # which readers look for within the first 1024 bytes.
    if response.status_code != 200 or b"%PDF" not in response.content[:1024]:
        st.error("Failed to download PDF.")
        return ""

    reader = PdfReader(io.BytesIO(response.content))
    # Defensively treat a page whose extraction yields None as empty text so
    # one text-less page cannot break the join.
    return "\n".join(page.extract_text() or "" for page in reader.pages)

# Streamlit Interface
st.title("ASD Diagnosis Retrieval-Augmented Generation App")


@st.cache_resource(show_spinner=False)
def _build_qa_chain(text, chunk_size=500, chunk_overlap=50, top_k=3):
    """Build a RetrievalQA chain that answers questions over *text*.

    Streamlit reruns this whole script on every widget interaction. Caching
    here means the embedding model, FAISS index and generation pipeline are
    built once per distinct text instead of once per rerun.
    """
    # Index overlapping chunks rather than the whole document as one entry:
    # retrieval can then select relevant passages, and the "stuff" chain
    # (which pastes retrieved text into the prompt) stays a manageable size.
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    chunks = splitter.split_text(text)

    embeddings_model = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )
    faiss_index = FAISS.from_texts(chunks, embeddings_model)

    # flan-t5 is an encoder-decoder (seq2seq) model, so it must be loaded with
    # the "text2text-generation" task; "text-generation" expects a causal LM.
    hf_pipeline = pipeline(
        "text2text-generation",
        model="google/flan-t5-base",
        tokenizer="google/flan-t5-base",
        max_new_tokens=256,
    )
    llm = HuggingFacePipeline(pipeline=hf_pipeline)

    # Limit retrieved chunks so the stuffed prompt stays near flan-t5's
    # 512-token training context.
    retriever = faiss_index.as_retriever(search_kwargs={"k": top_k})
    return RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever)


st.info("Processing predefined PDF...")
extracted_text = extract_text_from_pdf(pdf_url)

# Whitespace-only text (e.g. a scanned PDF with no text layer) would yield no
# chunks to index, so treat it the same as a failed extraction.
if extracted_text and extracted_text.strip():
    st.success("Text extraction complete.")

    with st.spinner("Setting up RAG pipeline..."):
        qa_chain = _build_qa_chain(extracted_text)

    # Query interface
    st.success("RAG pipeline ready.")
    user_query = st.text_input("Enter your query about ASD:")

    if user_query:
        with st.spinner("Fetching response..."):
            response = qa_chain.run(user_query)
        st.success(response)
else:
    st.error("No text extracted from the PDF.")