Spaces:
Sleeping
Sleeping
import streamlit as st | |
from transformers import pipeline | |
import pdfplumber | |
import logging | |
import pandas as pd | |
import docx | |
import pickle | |
import os | |
from hashlib import sha256 | |
# Create this module's logger and configure root logging at INFO level.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
@st.cache_resource
def _load_qa_pipeline():
    """Build the RoBERTa question-answering pipeline (cached by Streamlit)."""
    logger.info("Initializing QA model...")
    qa_pipeline = pipeline("question-answering", model="deepset/roberta-base-squad2")
    logger.info("QA model loaded successfully.")
    return qa_pipeline


def init_qa_model():
    """Return the question-answering pipeline, or None if it cannot be loaded.

    Streamlit re-executes the script on every widget interaction, so the
    model is built inside a ``st.cache_resource`` helper and reused across
    reruns instead of being reloaded each time. A failed load raises inside
    the helper, so it is not cached and is retried on the next rerun.

    Returns:
        The ``transformers`` pipeline object, or ``None`` when loading
        raised; in that case the error is logged and shown in the UI.
    """
    try:
        return _load_qa_pipeline()
    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception("Error loading QA model: %s", e)
        st.error(f"Error loading the QA model: {e}")
        return None
def extract_text_from_pdf(pdf_file):
    """Extract the text of every page of a PDF.

    Args:
        pdf_file: The source handed to ``pdfplumber.open`` (``main`` passes
            a Streamlit uploaded file).

    Returns:
        The page texts joined by newlines, ``"No text found in the PDF."``
        when no page yields text, or ``"Error extracting text from PDF."``
        when reading fails (the error is logged).
    """
    try:
        with pdfplumber.open(pdf_file) as pdf:
            page_texts = (page.extract_text() for page in pdf.pages)
            # Skip pages that yield no text, and separate pages with a
            # newline so the last word of one page is not fused with the
            # first word of the next.
            text = "\n".join(page_text for page_text in page_texts if page_text)
        return text or "No text found in the PDF."
    except Exception as e:
        logger.exception("Error extracting text from PDF: %s", e)
        return "Error extracting text from PDF."
def extract_text_from_txt(txt_file):
    """Decode an uploaded plain-text file as UTF-8.

    Args:
        txt_file: An object whose ``getvalue()`` returns the file's bytes.

    Returns:
        The decoded text, ``"No text found in the TXT file."`` when the file
        is empty, or ``"Error extracting text from TXT file."`` when reading
        or decoding fails (the error is logged).
    """
    try:
        # "utf-8-sig" decodes ordinary UTF-8 unchanged but also drops a
        # leading byte-order mark, which would otherwise end up in the text.
        text = txt_file.getvalue().decode("utf-8-sig")
    except Exception as e:
        logger.exception("Error extracting text from TXT file: %s", e)
        return "Error extracting text from TXT file."
    return text or "No text found in the TXT file."
def extract_text_from_csv(csv_file):
    """Render an uploaded CSV file as plain text.

    Args:
        csv_file: The source handed to ``pandas.read_csv``.

    Returns:
        The table formatted by ``DataFrame.to_string(index=False)``,
        ``"No text found in the CSV file."`` when the file is empty or has
        no data rows, or ``"Error extracting text from CSV file."`` when
        parsing fails (the error is logged).
    """
    try:
        df = pd.read_csv(csv_file)
        # to_string() on an empty frame returns "Empty DataFrame ..." rather
        # than "", so emptiness has to be checked explicitly.
        if df.empty:
            return "No text found in the CSV file."
        return df.to_string(index=False)
    except pd.errors.EmptyDataError:
        # A zero-byte file is "nothing to read", not a parsing failure.
        return "No text found in the CSV file."
    except Exception as e:
        logger.exception("Error extracting text from CSV file: %s", e)
        return "Error extracting text from CSV file."
def extract_text_from_docx(docx_file):
    """Return the paragraphs of a DOCX file joined by newlines.

    Falls back to "No text found in the DOCX file." when the joined text is
    empty, and to "Error extracting text from DOCX file." when the document
    cannot be read (the error is logged).
    """
    try:
        document = docx.Document(docx_file)
        paragraph_texts = (paragraph.text for paragraph in document.paragraphs)
        joined = "\n".join(paragraph_texts)
        if not joined:
            return "No text found in the DOCX file."
        return joined
    except Exception as e:
        logger.error(f"Error extracting text from DOCX file: {e}")
        return "Error extracting text from DOCX file."
def generate_cache_key(text):
    """Return the hexadecimal SHA-256 digest of *text* encoded as UTF-8."""
    digest = sha256()
    digest.update(text.encode('utf-8'))
    return digest.hexdigest()
def cache_embeddings(embeddings, cache_key):
    """Pickle *embeddings* to ``embeddings_cache/<cache_key>.pkl``.

    Caching is best effort: any failure is logged and swallowed.

    Args:
        embeddings: The object to store; it must be picklable.
        cache_key: The file name stem of the cache entry.
    """
    cache_dir = "embeddings_cache"
    cache_path = os.path.join(cache_dir, f"{cache_key}.pkl")
    try:
        # Create the same directory the file is written to. Previously
        # '../embeddings_cache' was created while the file went to
        # 'embeddings_cache/', so the write failed unless that directory
        # already existed.
        os.makedirs(cache_dir, exist_ok=True)
        with open(cache_path, 'wb') as f:
            pickle.dump(embeddings, f)
        logger.info(f"Embeddings cached successfully with key {cache_key}")
    except Exception as e:
        logger.exception("Error caching embeddings: %s", e)
def load_cached_embeddings(cache_key):
    """Load the pickled object stored under *cache_key*, if there is one.

    Returns the unpickled object from ``embeddings_cache/<cache_key>.pkl``,
    or ``None`` when the file does not exist or cannot be read (read errors
    are logged).
    """
    try:
        cache_path = f"embeddings_cache/{cache_key}.pkl"
        if not os.path.exists(cache_path):
            return None
        # NOTE(review): pickle.load can execute arbitrary code from the
        # file, so the cache directory must only hold files this app wrote.
        with open(cache_path, 'rb') as cache_file:
            cached = pickle.load(cache_file)
            logger.info(f"Embeddings loaded from cache with key {cache_key}")
            return cached
    except Exception as e:
        logger.error(f"Error loading cached embeddings: {e}")
        return None
def _extract_uploaded_text(uploaded_file):
    """Extract text from one upload, choosing the extractor by MIME type.

    Browsers do not always report the expected MIME type, so the file
    extension is used as a fallback. Returns the extractor's result, or
    ``None`` when neither the MIME type nor the extension is recognised.
    """
    by_mime = {
        "application/pdf": extract_text_from_pdf,
        "text/plain": extract_text_from_txt,
        "application/vnd.ms-excel": extract_text_from_csv,
        "text/csv": extract_text_from_csv,
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document": extract_text_from_docx,
    }
    by_extension = {
        ".pdf": extract_text_from_pdf,
        ".txt": extract_text_from_txt,
        ".csv": extract_text_from_csv,
        ".docx": extract_text_from_docx,
    }
    extractor = by_mime.get(uploaded_file.type)
    if extractor is None:
        extension = os.path.splitext(uploaded_file.name)[1].lower()
        extractor = by_extension.get(extension)
    if extractor is None:
        return None
    return extractor(uploaded_file)


def main():
    """Render the Streamlit page: collect context, take a question, answer it.

    The context is the text of all uploaded documents plus the manually
    entered text. Pressing "Get Answer" runs the question-answering pipeline
    over that context and displays the answer.
    """
    st.title("Adnan AI Labs QA System")
    st.markdown("Upload documents (PDF, TXT, CSV, or DOCX) or add context manually, and ask questions.")
    uploaded_files = st.file_uploader("Upload Documents", type=["pdf", "txt", "csv", "docx"], accept_multiple_files=True)
    extracted_text_box = st.text_area("Manually add extra context for answering questions", height=200)

    # May be None when the model failed to load; checked before use below.
    qa_pipeline = init_qa_model()

    # The extractors report problems by returning a message instead of
    # raising. Those messages must not become context for the model.
    failure_prefixes = ("No text found in the ", "Error extracting text from ")
    document_texts = []
    for uploaded_file in uploaded_files or []:
        text = _extract_uploaded_text(uploaded_file)
        if text is None:
            st.warning(f"Unsupported file type for {uploaded_file.name}; the file was skipped.")
        elif text.startswith(failure_prefixes):
            st.warning(f"{uploaded_file.name}: {text}")
        else:
            document_texts.append(text)

    # Combine all extracted texts and the manual context.
    combined_context = "\n".join(document_texts) + "\n" + extracted_text_box

    user_question = st.text_input("Ask a question:")
    if user_question and combined_context.strip():
        if st.button("Get Answer"):
            if qa_pipeline is None:
                # Calling None would raise TypeError; report it instead.
                st.error("The QA model is not available, so the question cannot be answered.")
            else:
                with st.spinner('Processing your question...'):
                    cache_key = generate_cache_key(combined_context)
                    cached_embeddings = load_cached_embeddings(cache_key)
                    if cached_embeddings is None:
                        # TODO: compute embeddings for combined_context and
                        # store them with cache_embeddings(). Until then
                        # there is nothing to store; writing None would
                        # only leave a useless file per distinct context.
                        logger.info("No cached embeddings found; embedding generation is not implemented yet.")
                    answer = qa_pipeline(question=user_question, context=combined_context)
                    if answer['answer']:
                        st.write("Answer:", answer['answer'])
                    else:
                        st.warning("No suitable answer found. Please rephrase your question.")
    elif not user_question:
        st.info("Please enter a question to get an answer.")
    else:
        # A question was entered, so the context must be what is missing.
        st.info("Please upload a document or add context manually.")

    # Display Buy Me a Coffee button
    st.markdown("""
<div style="text-align: center;">
<p>If you find this project useful, consider buying me a coffee to support further development! ☕️</p>
<a href="https://buymeacoffee.com/adnanailabs">
<img src="https://cdn.buymeacoffee.com/buttons/v2/default-yellow.png" alt="Buy Me a Coffee" style="height: 50px;">
</a>
</div>
""", unsafe_allow_html=True)
if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        # Last-resort boundary: show the failure in the UI and log it with
        # its traceback (exc_info=True) so the crash can be diagnosed.
        logger.critical(f"Critical error: {e}", exc_info=True)
        st.error(f"A critical error occurred: {e}")