File size: 2,902 Bytes
1b59b58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0431e9d
 
1b59b58
 
0431e9d
 
 
 
 
 
 
 
 
1b59b58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# app.py
import os
import streamlit as st
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain.llms import HuggingFacePipeline
from transformers import pipeline
from groq import Groq
import requests
from PyPDF2 import PdfReader
import io

# Set up API key for Groq API.
# SECURITY: a hard-coded API key was previously committed here (even
# commented out it remains exposed in version-control history and must be
# revoked/rotated). The key is now read exclusively from the environment —
# see get_groq_client() below.
def get_groq_client():
    """Create a Groq API client authenticated from the environment.

    Looks up the conventional uppercase ``GROQ_API_KEY`` variable first and
    falls back to the legacy lowercase ``groq_api_key`` name for backward
    compatibility (environment variables are case-sensitive on Linux).

    Returns:
        Groq: an authenticated Groq API client.

    Raises:
        ValueError: if neither environment variable is set.
    """
    api_key = os.getenv("GROQ_API_KEY") or os.getenv("groq_api_key")
    if not api_key:
        raise ValueError("Groq API key not found in environment variables.")
    return Groq(api_key=api_key)

# Module-level client so the rest of the app shares a single instance.
groq_client = get_groq_client()


# Predefined PDF link.
# NOTE(review): this must be a publicly shared ("anyone with the link")
# Google Drive URL, and it must keep the "/d/<file_id>/view" shape —
# extract_text_from_pdf() parses the file id out of exactly that segment.
pdf_url = "https://drive.google.com/file/d/1P9InkDWyaybb8jR_xS4f4KsxTlYip8RA/view?usp=drive_link"

def extract_text_from_pdf(pdf_url):
    """Download a PDF from a Google Drive share link and return its text.

    Args:
        pdf_url: a Drive URL of the form
            ``https://drive.google.com/file/d/<file_id>/view...``.

    Returns:
        str: the newline-joined text of all pages, or ``""`` on any
        failure (the error is surfaced in the Streamlit UI, not raised).
    """
    # Pull the file id out of the "/d/<id>/view" segment and build the
    # direct-download endpoint. Guard against links that lack that shape.
    try:
        file_id = pdf_url.split('/d/')[1].split('/view')[0]
    except IndexError:
        st.error("Invalid Google Drive link format.")
        return ""
    download_url = f"https://drive.google.com/uc?export=download&id={file_id}"
    # Timeout so a stalled download cannot hang the Streamlit script forever.
    response = requests.get(download_url, timeout=30)

    if response.status_code == 200:
        pdf_content = io.BytesIO(response.content)
        reader = PdfReader(pdf_content)
        # extract_text() may return None for image-only pages; coerce to "".
        text = "\n".join([(page.extract_text() or "") for page in reader.pages])
        return text
    else:
        st.error("Failed to download PDF.")
        return ""

# Streamlit Interface — flat script: download the PDF, index it, and serve
# a retrieval-augmented QA loop over its contents.
st.title("ASD Diagnosis Retrieval-Augmented Generation App")

st.info("Processing predefined PDF...")
extracted_text = extract_text_from_pdf(pdf_url)

if extracted_text:
    st.success("Text extraction complete.")

    # Build the vector store. FAISS.from_texts computes the embeddings
    # itself, so the previous separate embed_documents() pass was redundant
    # (its result was never used) and has been removed.
    st.info("Generating embeddings...")
    embeddings_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

    st.info("Storing embeddings in FAISS...")
    faiss_index = FAISS.from_texts([extracted_text], embeddings_model)

    # Set up the Hugging Face LLM pipeline.
    # BUGFIX: flan-t5 is an encoder-decoder (seq2seq) model, so the correct
    # pipeline task is "text2text-generation", not "text-generation".
    st.info("Setting up RAG pipeline...")
    hf_pipeline = pipeline("text2text-generation", model="google/flan-t5-base", tokenizer="google/flan-t5-base")
    llm = HuggingFacePipeline(pipeline=hf_pipeline)

    # "stuff" chain: retrieved chunks are stuffed directly into the prompt.
    retriever = faiss_index.as_retriever()
    qa_chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever)

    # Query interface
    st.success("RAG pipeline ready.")
    user_query = st.text_input("Enter your query about ASD:")

    if user_query:
        st.info("Fetching response...")
        response = qa_chain.run(user_query)
        st.success(response)
else:
    st.error("No text extracted from the PDF.")