import streamlit as st
import torch
from langchain_text_splitters import Language
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from langchain.llms import HuggingFacePipeline
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline


def init():
    # Make sure the session-state keys exist before the first rerun touches them.
    if "conversation" not in st.session_state:
        st.session_state.conversation = None
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = None


def init_llm_pipeline():
    # Load StarCoder2 once per session, quantized to 4 bit so it fits on a single GPU.
    if "llm" not in st.session_state:
        model_id = "bigcode/starcoder2-15b"
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            quantization_config=quantization_config,
            device_map="auto",
        )
        tokenizer.add_eos_token = True
        tokenizer.pad_token_id = 0
        tokenizer.padding_side = "left"

        text_generation_pipeline = pipeline(
            model=model,
            tokenizer=tokenizer,
            task="text-generation",
            do_sample=True,  # required for temperature to take effect
            temperature=0.7,
            repetition_penalty=1.1,
            return_full_text=True,
            max_new_tokens=300,
        )
        # Wrap the transformers pipeline so LangChain can use it as an LLM.
        st.session_state.llm = HuggingFacePipeline(pipeline=text_generation_pipeline)


def get_text(docs):
    # st.file_uploader with accept_multiple_files=True returns a list of files;
    # decode each uploaded file into a plain string.
    return [doc.getvalue().decode("utf-8") for doc in docs]


def get_vectorstore(texts):
    # Split the uploaded source code along Python syntax boundaries.
    python_splitter = RecursiveCharacterTextSplitter.from_language(
        language=Language.PYTHON, chunk_size=2000, chunk_overlap=200
    )
    documents = python_splitter.create_documents(texts)
    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    db = FAISS.from_documents(documents, embeddings)
    retriever = db.as_retriever(
        search_type="mmr",  # also try "similarity"
        search_kwargs={"k": 8},
    )
    return retriever


def get_conversation(retriever):
    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
    conversation_chain = ConversationalRetrievalChain.from_llm(
        llm=st.session_state.llm,
        retriever=retriever,
        memory=memory,
    )
    return conversation_chain


def handle_user_input(question):
    if st.session_state.conversation is None:
        st.warning("Please upload and process your code files first.")
        return
    response = st.session_state.conversation({"question": question})
    st.session_state.chat_history = response["chat_history"]
    # The memory buffer alternates between user and assistant messages.
    for i, message in enumerate(st.session_state.chat_history):
        role = "user" if i % 2 == 0 else "assistant"
        with st.chat_message(role):
            st.write(message.content)


def main():
    st.set_page_config(page_title="Coding Assistant", page_icon=":books:")
    init()
    st.header(":books: Coding Assistant")

    user_input = st.chat_input("Ask your question here")
    if user_input:
        with st.spinner("Running query ..."):
            handle_user_input(user_input)

    with st.sidebar:
        st.subheader("Code upload")
        upload_docs = st.file_uploader("Upload documents here", accept_multiple_files=True)
        if st.button("Upload"):
            with st.spinner("Analyzing documents ..."):
                init_llm_pipeline()
                raw_texts = get_text(upload_docs)
                retriever = get_vectorstore(raw_texts)
                st.session_state.conversation = get_conversation(retriever)


if __name__ == "__main__":
    main()
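
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original script): a rough guess at the
# dependencies and launch command, assuming the file is saved as app.py.
# Exact package versions may differ, and faiss-gpu can replace faiss-cpu on
# machines with CUDA.
#
#   pip install streamlit torch transformers accelerate bitsandbytes \
#       langchain langchain-text-splitters sentence-transformers faiss-cpu
#   streamlit run app.py
#
# Workflow: uploaded Python files are split along syntax boundaries, embedded
# with all-MiniLM-L6-v2, indexed in FAISS, and queried through a
# ConversationalRetrievalChain backed by the 4-bit quantized StarCoder2-15B.
# ---------------------------------------------------------------------------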