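"""RAG-Based Research Paper Analyzer.

A Streamlit app that ingests PDF papers, chunks and embeds their text with
sentence-transformers, indexes the embeddings in FAISS, and answers user
questions by retrieving the most relevant chunks and passing them as context
to a Groq-hosted LLM.

Run with (assuming this file is saved as app.py):
    streamlit run app.py

Requires the GROQ_API_KEY environment variable to be set.
"""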
import os
import streamlit as st
from PyPDF2 import PdfReader
from sentence_transformers import SentenceTransformer
import faiss
import matplotlib.pyplot as plt
import numpy as np
from groq import Groq

# Read the API key from the environment rather than hardcoding a secret
GROQ_API_KEY = os.environ["GROQ_API_KEY"]  # fails fast if the key is not set
client = Groq(api_key=GROQ_API_KEY)

# Initialize Embedding Model (cached so Streamlit reruns don't reload the weights)
@st.cache_resource
def load_embedding_model():
    return SentenceTransformer('all-MiniLM-L6-v2')

embedding_model = load_embedding_model()

# Initialize FAISS Index
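# Note: this module-level index is rebuilt on every Streamlit rerun, so the
# uploaded files are re-embedded on each interaction. For larger corpora,
# consider keeping the index and metadata in st.session_state instead.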
embedding_dim = 384  # Dimensionality of 'all-MiniLM-L6-v2'
faiss_index = faiss.IndexFlatL2(embedding_dim)

# Store Metadata
metadata_store = []

def extract_text_from_pdf(pdf_file):
    """Concatenate the extractable text of every page in a PDF."""
    pdf_reader = PdfReader(pdf_file)
    text = ""
    for page in pdf_reader.pages:
        # extract_text() can return None for pages with no extractable text
        text += page.extract_text() or ""
    return text

def chunk_text(text, chunk_size=500):
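    """Split text into non-overlapping chunks of roughly chunk_size words."""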
    words = text.split()
    return [' '.join(words[i:i+chunk_size]) for i in range(0, len(words), chunk_size)]

def generate_embeddings(chunks):
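    """Encode each chunk into a dense vector with the sentence-transformer."""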
    return embedding_model.encode(chunks)

def store_embeddings(embeddings, metadata):
    """Add embeddings to the FAISS index and record their metadata in parallel."""
    # FAISS expects float32 vectors; metadata_store[i] describes vector i
    faiss_index.add(np.asarray(embeddings, dtype="float32"))
    metadata_store.extend(metadata)

def retrieve_relevant_chunks(query, k=5):
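    """Return the k nearest chunks to the query as (metadata, distance) pairs."""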
    query_embedding = embedding_model.encode([query])
    distances, indices = faiss_index.search(query_embedding, k)
    
    # Safeguard: FAISS pads results with -1 when fewer than k vectors are
    # indexed, so keep only indices that are valid positions in metadata_store
    valid_results = [
        (metadata_store[i], distances[0][j])
        for j, i in enumerate(indices[0])
        if 0 <= i < len(metadata_store)
    ]
    return valid_results

def ask_groq_api(question, context):
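    """Send the retrieved context plus the user question to the Groq chat API."""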
    chat_completion = client.chat.completions.create(
        messages=[
            {"role": "user", "content": f"{context}\n\n{question}"}
        ],
        model="llama3-8b-8192"
    )
    return chat_completion.choices[0].message.content

# Streamlit App
st.title("RAG-Based Research Paper Analyzer")

uploaded_files = st.file_uploader("Upload PDF Files", accept_multiple_files=True, type="pdf")

if uploaded_files:
    all_chunks = []
    all_metadata = []
    
    for uploaded_file in uploaded_files:
        text = extract_text_from_pdf(uploaded_file)
        chunks = chunk_text(text)
        embeddings = generate_embeddings(chunks)
        metadata = [{"chunk": chunk, "file_name": uploaded_file.name} for chunk in chunks]
        store_embeddings(embeddings, metadata)
        all_chunks.extend(chunks)
        all_metadata.extend(metadata)
    
    st.success("Files uploaded and processed successfully!")

    if st.button("View Topic Summaries"):
        for chunk in all_chunks[:3]:
            st.write(chunk)

    user_question = st.text_input("Ask a question about the uploaded papers:")
    if user_question:
        relevant_chunks = retrieve_relevant_chunks(user_question)
        if relevant_chunks:
            context = "\n\n".join(meta["chunk"] for meta, _ in relevant_chunks)
            answer = ask_groq_api(user_question, context)
            st.write("**Answer:**", answer)
        else:
            st.write("No relevant sections found for your question.")

    if st.button("Generate Scatter Plot"):
        st.write("Generating scatter plot for methods vs. results...")
        # Example scatter plot (replace with real data)
        fig, ax = plt.subplots()
        x = np.random.rand(10)
        y = np.random.rand(10)
        ax.scatter(x, y)
        ax.set_xlabel("Methods")
        ax.set_ylabel("Results")
        st.pyplot(fig)  # pass an explicit figure; the global-pyplot form is deprecated

    st.text_area("Annotate Your Insights:", height=100, key="annotations")