"""Streamlit app for chatting with an uploaded PDF using LaMini-T5, LangChain, and Chroma."""

import os

import torch
import streamlit as st
from streamlit_chat import message
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from langchain.chains import RetrievalQA
from langchain.vectorstores import Chroma
from langchain.llms import HuggingFacePipeline
from langchain.embeddings.base import Embeddings
from langchain.document_loaders import PDFMinerLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from constants import CHROMA_SETTINGS

st.set_page_config(layout="centered")

# Load the LaMini-T5 checkpoint once at startup; it is reused both for text
# generation and for the encoder-based document embeddings below.
checkpoint = "MBZUAI/LaMini-T5-738M"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(
    checkpoint,
    device_map="auto",
    torch_dtype=torch.float32,
)


class T5EncoderEmbeddings(Embeddings):
    """Embeddings built from the mean-pooled hidden states of the T5 encoder.

    Chroma expects an object implementing the LangChain ``Embeddings``
    interface (``embed_documents`` / ``embed_query``), so the raw encoder
    pooling is wrapped in this small class.
    """

    def _embed(self, text):
        inputs = tokenizer(
            text, return_tensors="pt", truncation=True, max_length=512
        ).to(model.device)
        with torch.no_grad():
            hidden = model.encoder(**inputs).last_hidden_state
        return hidden.mean(dim=1).squeeze(0).cpu().tolist()

    def embed_documents(self, texts):
        return [self._embed(text) for text in texts]

    def embed_query(self, text):
        return self._embed(text)


@st.cache_resource
def data_ingestion(filepath):
    """Load the uploaded PDF, split it into chunks, and persist a Chroma index."""
    loader = PDFMinerLoader(filepath)
    documents = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    texts = text_splitter.split_documents(documents)

    db = Chroma.from_documents(
        texts,
        T5EncoderEmbeddings(),
        persist_directory="db",
        client_settings=CHROMA_SETTINGS,
    )
    db.persist()


@st.cache_resource
def llm_pipeline():
    """Wrap the LaMini-T5 model in a LangChain-compatible generation pipeline."""
    pipe = pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tokenizer,
        max_length=256,
        do_sample=True,
        temperature=0.3,
        top_p=0.95,
    )
    return HuggingFacePipeline(pipeline=pipe)


@st.cache_resource
def qa_llm():
    """Build a RetrievalQA chain over the persisted Chroma index."""
    llm = llm_pipeline()
    db = Chroma(
        persist_directory="db",
        embedding_function=T5EncoderEmbeddings(),
        client_settings=CHROMA_SETTINGS,
    )
    retriever = db.as_retriever()
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
    )
    return qa


def process_answer(instruction):
    qa = qa_llm()
    generated_text = qa(instruction)
    return generated_text["result"]


def display_conversation(history):
    for i in range(len(history["generated"])):
        message(history["past"][i], is_user=True, key=str(i) + "_user")
        message(history["generated"][i], key=str(i))


def main():
    st.markdown(
        "<h1 style='text-align: center;'>Chat with your PDF</h1>",
        unsafe_allow_html=True,
    )
    st.markdown(
        "<h2 style='text-align: center;'>Upload your PDF</h2>",
        unsafe_allow_html=True,
    )

    uploaded_file = st.file_uploader("", type=["pdf"])

    if uploaded_file is not None:
        # Ensure the 'docs' directory exists before saving the upload.
        if not os.path.exists("docs"):
            os.makedirs("docs")
        filepath = os.path.join("docs", uploaded_file.name)
        with open(filepath, "wb") as temp_file:
            temp_file.write(uploaded_file.read())

        with st.spinner("Creating embeddings..."):
            data_ingestion(filepath)
        st.success("Embeddings created successfully!")

        user_input = st.text_input("", key="input")

        if "generated" not in st.session_state:
            st.session_state["generated"] = ["I am ready to help you"]
        if "past" not in st.session_state:
            st.session_state["past"] = ["Hey there!"]

        if user_input:
            answer = process_answer({"query": user_input})
            st.session_state["past"].append(user_input)
            st.session_state["generated"].append(answer)

        display_conversation(st.session_state)


if __name__ == "__main__":
    main()