"""Retrieval QA helper built on a local FastChat-T5 pipeline and a Chroma store.

Importing this module loads the Instructor embedding model, the
``lmsys/fastchat-t5-3b-v1.0`` tokenizer and model, and opens the Chroma
vector store persisted under ``persist_directory``.
"""

import os
import textwrap

import streamlit as st
import torch
import transformers
from InstructorEmbedding import INSTRUCTOR
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import pipeline
from langchain.llms import HuggingFacePipeline
from langchain.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA
from langchain.document_loaders import TextLoader
from langchain.document_loaders import PyPDFLoader
from langchain.document_loaders import DirectoryLoader
from langchain.embeddings import HuggingFaceInstructEmbeddings
# NOTE: this import rebinds ``Chroma``; it must stay after the
# ``langchain.vectorstores`` import above so that this is the class in use.
from langchain_community.vectorstores import Chroma

# Directory holding the persisted Chroma collection.
persist_directory = 'db'

# Embedding model handed to Chroma as its embedding function.
instructor_embeddings = HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-base")
embedding = instructor_embeddings

# The tokenizer is loaded with the slow implementation and legacy=False
# (an earlier variant of this line used the default arguments).
tokenizer = AutoTokenizer.from_pretrained(
    "lmsys/fastchat-t5-3b-v1.0", use_fast=False, legacy=False
)
model = AutoModelForSeq2SeqLM.from_pretrained("lmsys/fastchat-t5-3b-v1.0")

# Text-to-text generation pipeline wrapped as a LangChain LLM.
pipe = pipeline(
    "text2text-generation", model=model, tokenizer=tokenizer, max_length=256
)
local_llm = HuggingFacePipeline(pipeline=pipe)

# Vector store opened from ``persist_directory``; the retriever is configured
# with k=3 results per query.
vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding)
retriever = vectordb.as_retriever(search_kwargs={"k": 3})


def get_lpphelper_chain():
    """Build the retrieval QA chain.

    Returns:
        A ``RetrievalQA`` chain of type ``"stuff"`` that uses the module-level
        ``local_llm`` and ``retriever`` and is configured with
        ``return_source_documents=True``.
    """
    qa_chain = RetrievalQA.from_chain_type(
        llm=local_llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
    )
    return qa_chain


def wrap_text_preserve_newlines(text: str, width: int = 110) -> str:
    """Wrap *text* to *width* columns while keeping its existing line breaks.

    Args:
        text: The text to wrap.
        width: Maximum length of each wrapped line.

    Returns:
        The text with every original line wrapped independently and the
        pieces re-joined with newline characters.
    """
    # Split on the existing newlines so each original line is wrapped on its
    # own rather than being merged into one paragraph.
    lines = text.split('\n')
    wrapped_lines = [textwrap.fill(line, width=width) for line in lines]
    return '\n'.join(wrapped_lines)


def process_llm_response(llm_response) -> str:
    """Print the wrapped answer with its sources and return the wrapped answer.

    Args:
        llm_response: Mapping with a ``'result'`` entry (the answer text) and
            a ``'source_documents'`` entry (an iterable of objects whose
            ``metadata['source']`` identifies where each document came from).

    Returns:
        The answer text wrapped by ``wrap_text_preserve_newlines``; the
        sources are printed but are not part of the return value.
    """
    wrap_text = wrap_text_preserve_newlines(llm_response['result'])

    # Collect the source identifiers up front. The previous code called
    # ``str.join`` and discarded its result, so no source was ever shown.
    source_paths = [
        source.metadata['source'] for source in llm_response["source_documents"]
    ]

    print(wrap_text)
    print('\n\nSources:')
    for path in source_paths:
        print(path)

    # NOTE(review): the original ended with ``.replace("", "")``, which is a
    # no-op. If some marker was meant to be stripped from the answer, restore
    # that replacement here.
    return wrap_text


if __name__ == "__main__":
    get_lpphelper_chain()