import os

import torch
import streamlit as st
from streamlit_chat import message
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain.chains import RetrievalQA
from langchain.vectorstores import Chroma
from langchain.llms import HuggingFacePipeline
from langchain.document_loaders import PDFMinerLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings.base import Embeddings
from constants import CHROMA_SETTINGS

st.set_page_config(layout="centered")

checkpoint = "meta-llama/Llama-2-7b-chat-hf"
token = os.getenv("HF_TOKEN")  # gated model: requires an accepted license and a valid Hugging Face token
tokenizer = AutoTokenizer.from_pretrained(checkpoint, use_auth_token=token)
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    use_auth_token=token,
    device_map="auto",
    torch_dtype=torch.float32,
)


class MeanPoolEmbeddings(Embeddings):
    """Embeds text by mean-pooling the last hidden layer of the causal LM.

    LangChain's Chroma wrapper expects an Embeddings object exposing
    embed_documents/embed_query rather than a bare function, and a causal LM
    only returns hidden states when they are explicitly requested.
    """

    def _embed(self, text):
        inputs = tokenizer(
            text, return_tensors="pt", truncation=True, max_length=512
        ).to(model.device)
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
        # Mean-pool the final hidden layer -> one vector per text.
        return outputs.hidden_states[-1].mean(dim=1).squeeze(0).cpu().numpy().tolist()

    def embed_documents(self, texts):
        # Embed chunks one at a time; the Llama tokenizer has no padding token for batching.
        return [self._embed(text) for text in texts]

    def embed_query(self, text):
        return self._embed(text)


embeddings = MeanPoolEmbeddings()


@st.cache_resource
def data_ingestion(filepath):
    loader = PDFMinerLoader(filepath)
    documents = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    texts = text_splitter.split_documents(documents)

    # Build and persist the vector store on disk so qa_llm() can reopen it later.
    db = Chroma.from_documents(
        texts,
        embedding=embeddings,
        persist_directory="db",
        client_settings=CHROMA_SETTINGS,
    )
    db.persist()
    db = None


@st.cache_resource
def llm_pipeline():
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=256,  # cap only the generated tokens; max_length would also count the stuffed context
        do_sample=True,
        temperature=0.3,
        top_p=0.95,
    )
    local_llm = HuggingFacePipeline(pipeline=pipe)
    return local_llm


@st.cache_resource
def qa_llm():
    llm = llm_pipeline()
    db = Chroma(
        persist_directory="db",
        embedding_function=embeddings,
        client_settings=CHROMA_SETTINGS,
    )
    retriever = db.as_retriever()
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
    )
    return qa


def process_answer(instruction):
    qa = qa_llm()
    generated_text = qa(instruction)
    answer = generated_text["result"]
    return answer


def display_conversation(history):
    # Render the stored user/assistant turns with streamlit_chat widgets.
    for i in range(len(history["generated"])):
        message(history["past"][i], is_user=True, key=str(i) + "_user")
        message(history["generated"][i], key=str(i))


def main():
    st.markdown("