import os
import torch
import streamlit as st
from streamlit_chat import message
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from langchain.chains import RetrievalQA
from langchain.vectorstores import Chroma
from langchain.llms import HuggingFacePipeline
from langchain.document_loaders import PDFMinerLoader
from langchain.embeddings.base import Embeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from constants import CHROMA_SETTINGS

st.set_page_config(layout="centered")

checkpoint = "MBZUAI/LaMini-T5-738M"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(
    checkpoint, device_map="auto", torch_dtype=torch.float32
)


class T5EncoderEmbeddings(Embeddings):
    """Embed text by mean-pooling the LaMini-T5 encoder's last hidden state.

    Implements LangChain's Embeddings interface so the same object can be
    passed to Chroma for both ingestion and retrieval.
    """

    def _embed(self, text):
        inputs = tokenizer(
            text, return_tensors="pt", truncation=True, max_length=512
        ).to(model.device)
        with torch.no_grad():
            hidden = model.encoder(**inputs).last_hidden_state
        return hidden.mean(dim=1).squeeze(0).cpu().numpy().tolist()

    def embed_documents(self, texts):
        return [self._embed(text) for text in texts]

    def embed_query(self, text):
        return self._embed(text)


embeddings = T5EncoderEmbeddings()


@st.cache_resource
def data_ingestion(filepath):
    # Load the PDF, split it into overlapping chunks, and persist the vectors to Chroma.
    loader = PDFMinerLoader(filepath)
    documents = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    texts = text_splitter.split_documents(documents)
    db = Chroma.from_documents(
        texts,
        embeddings,
        persist_directory="db",
        client_settings=CHROMA_SETTINGS,
    )
    db.persist()
    db = None


@st.cache_resource
def llm_pipeline():
    # Wrap the local text2text model in a LangChain-compatible LLM.
    pipe = pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tokenizer,
        max_length=256,
        do_sample=True,
        temperature=0.3,
        top_p=0.95,
    )
    local_llm = HuggingFacePipeline(pipeline=pipe)
    return local_llm


@st.cache_resource
def qa_llm():
    # Connect the local LLM to the persisted Chroma index through a RetrievalQA chain.
    llm = llm_pipeline()
    db = Chroma(
        persist_directory="db",
        embedding_function=embeddings,
        client_settings=CHROMA_SETTINGS,
    )
    retriever = db.as_retriever()
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
    )
    return qa


def process_answer(instruction):
    qa = qa_llm()
    generated_text = qa(instruction)
    answer = generated_text["result"]
    return answer


def display_conversation(history):
    # Render the chat history with alternating user and assistant bubbles.
    for i in range(len(history["generated"])):
        message(history["past"][i], is_user=True, key=str(i) + "_user")
        message(history["generated"][i], key=str(i))


def main():
    st.markdown("