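# Streamlit app: upload Python source files, index them in a FAISS vector
# store, and chat about the code with a locally quantized StarCoder2 model
# through a LangChain ConversationalRetrievalChain.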
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline

from langchain.text_splitter import Language, RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.llms import HuggingFacePipeline
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain


def init():
    """Make sure the session-state keys exist before they are first read."""
    if "conversation" not in st.session_state:
        st.session_state.conversation = None
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = None
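# Streamlit reruns this script from top to bottom on every interaction;
# st.session_state is the only state that survives those reruns.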

def init_llm_pipeline():
    """Load StarCoder2-15B once per session and wrap it as a LangChain LLM."""
    if "llm" not in st.session_state:
        model_id = "bigcode/starcoder2-15b"
        # 4-bit quantization keeps the 15B model small enough for a single GPU.
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
        )

        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            quantization_config=quantization_config,
            device_map="auto",
        )
        # EOS/padding setup; left padding matters if prompts are ever batched.
        tokenizer.add_eos_token = True
        tokenizer.pad_token_id = 0
        tokenizer.padding_side = "left"

        text_generation_pipeline = pipeline(
            task="text-generation",
            model=model,
            tokenizer=tokenizer,
            temperature=0.7,
            do_sample=True,  # sampling must be enabled for temperature to apply
            repetition_penalty=1.1,
            return_full_text=True,
            max_new_tokens=300,
        )
        st.session_state.llm = HuggingFacePipeline(pipeline=text_generation_pipeline)
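# Rough sizing sketch (an estimate, not a measurement): 15B parameters at
# 4 bits are about 7.5 GB of weights; with KV cache and activations on top,
# a ~16 GB GPU is a realistic target for this configuration.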

def get_text(docs):
    """Decode each uploaded file into a UTF-8 string (one entry per file)."""
    return [doc.getvalue().decode("utf-8") for doc in docs]
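# Assumes text uploads; a binary file would raise UnicodeDecodeError above.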

def get_retriever(texts):
    """Chunk the code, embed the chunks, and return a FAISS-backed retriever."""
    python_splitter = RecursiveCharacterTextSplitter.from_language(
        language=Language.PYTHON, chunk_size=2000, chunk_overlap=200
    )
    documents = python_splitter.create_documents(texts)

    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

    db = FAISS.from_documents(documents, embeddings)
    retriever = db.as_retriever(
        search_type="mmr",  # also worth testing "similarity"
        search_kwargs={"k": 8},
    )
    return retriever
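# "mmr" (maximal marginal relevance) re-ranks the k results for diversity as
# well as relevance, which helps when many code chunks look nearly identical.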
    
def get_conversation(retriever):
    """Wire the LLM, the retriever, and a chat-history buffer into one chain."""
    memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)
    conversation_chain = ConversationalRetrievalChain.from_llm(
        llm=st.session_state.llm,
        retriever=retriever,
        memory=memory,
    )
    return conversation_chain
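# Per turn, ConversationalRetrievalChain condenses the chat history and the
# new question into a standalone query, retrieves matching chunks, and has
# the LLM answer from those chunks.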

def handle_user_input(question):
    """Run the conversation chain and re-render the whole chat history."""
    if st.session_state.conversation is None:
        st.warning("Please upload and analyze your code first.")
        return
    response = st.session_state.conversation({'question': question})
    st.session_state.chat_history = response['chat_history']
    # The memory alternates human/AI messages, so even indices are the user's.
    for i, message in enumerate(st.session_state.chat_history):
        role = "user" if i % 2 == 0 else "assistant"
        with st.chat_message(role):
            st.write(message.content)

def main():
    st.set_page_config(page_title="Coding Assistant", page_icon=":books:")
    init()

    st.header(":books: Coding Assistant")
    user_input = st.chat_input("Ask your question here")
    if user_input:
        with st.spinner("Running query ..."):
            handle_user_input(user_input)

    with st.sidebar:
        st.subheader("Code Upload")
        upload_docs = st.file_uploader("Upload your documents here", accept_multiple_files=True)
        if st.button("Upload"):
            with st.spinner("Analyzing documents ..."):
                init_llm_pipeline()
                texts = get_text(upload_docs)
                retriever = get_retriever(texts)
                st.session_state.conversation = get_conversation(retriever)


if __name__ == "__main__":
    main()
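# To launch (assuming this file is saved as app.py; adjust the name to match):
#   streamlit run app.py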