RAG-REV-01 / app.py
muhammadshaheryar's picture
Update app.py
e2c01c3 verified
import faiss
import fitz # PyMuPDF
import pandas as pd
from transformers import DPRQuestionEncoder, DPRContextEncoder, AutoTokenizer, pipeline
from sentence_transformers import SentenceTransformer
from docx import Document
import streamlit as st
import os
from bs4 import BeautifulSoup
# Initialize models and FAISS index
embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
index = faiss.IndexFlatL2(384) # 384-dimensional embeddings for this model
document_texts = []
document_mapping = {}
# Function to load and convert files to text
def load_text_from_files(file_path):
if file_path.endswith(".pdf"):
return extract_text_from_pdf(file_path)
elif file_path.endswith(".docx"):
return extract_text_from_docx(file_path)
elif file_path.endswith(".csv"):
return extract_text_from_csv(file_path)
elif file_path.endswith(".xlsx"):
return extract_text_from_xlsx(file_path)
elif file_path.endswith(".html"):
return extract_text_from_html(file_path)
else:
return ""
def extract_text_from_pdf(file_path):
text = ""
with fitz.open(file_path) as doc:
for page in doc:
text += page.get_text()
return text
def extract_text_from_docx(file_path):
doc = Document(file_path)
return " ".join([para.text for para in doc.paragraphs])
def extract_text_from_csv(file_path):
df = pd.read_csv(file_path)
return " ".join(df.apply(lambda row: " ".join(map(str, row)), axis=1))
def extract_text_from_xlsx(file_path):
df = pd.read_excel(file_path)
return " ".join(df.apply(lambda row: " ".join(map(str, row)), axis=1))
def extract_text_from_html(file_path):
with open(file_path, "r") as file:
soup = BeautifulSoup(file, "html.parser")
return soup.get_text()
# Indexing uploaded documents
def index_documents(uploaded_files):
global document_texts, document_mapping
for file in uploaded_files:
file_path = os.path.join("/content/temp/", file.name)
with open(file_path, "wb") as f:
f.write(file.read())
text = load_text_from_files(file_path)
if text:
document_texts.append(text)
embeddings = embedding_model.encode([text])
index.add(embeddings)
document_mapping[len(document_texts) - 1] = text
# Load retrieval and generation models
question_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
context_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
question_tokenizer = AutoTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
generator = pipeline("text-generation", model="gpt2")
# RAG pipeline function
def retrieve_and_generate(query):
query_embeddings = embedding_model.encode([query])
_, I = index.search(query_embeddings, k=5) # Top-5 relevant contexts
retrieved_texts = [document_mapping[idx] for idx in I[0]]
context = " ".join(retrieved_texts)
response = generator(f"{query} [SEP] {context}", max_length=150, num_return_sequences=1)
return response[0]['generated_text']
# Streamlit interface
st.title("Electrical Engineering RAG System")
st.write("Upload your files, ask questions, and get responses based on your data.")
uploaded_files = st.file_uploader("Upload Documents", accept_multiple_files=True, type=["pdf", "docx", "csv", "xlsx", "html"])
if uploaded_files:
index_documents(uploaded_files)
st.write("Files uploaded successfully! You can now ask questions.")
user_query = st.text_input("Ask a question:")
if user_query:
response = retrieve_and_generate(user_query)
st.write("Answer:", response)