# starcoder2-test / app.py
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
from langchain_text_splitters import Language
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.llms import HuggingFacePipeline
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain

# Leftover constants from an earlier OpenAI-based variant; not used below.
gpt_model = 'gpt-4-1106-preview'
embedding_model = 'text-embedding-3-small'
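# Assumption about the environment (not stated in this file): the langchain.*
# import paths above target a pre-0.2 LangChain release, and 4-bit loading via
# BitsAndBytesConfig additionally requires the bitsandbytes and accelerate
# packages alongside transformers.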
def init():
    if "conversation" not in st.session_state:
        st.session_state.conversation = None
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = None
def init_llm_pipeline():
    if "llm" not in st.session_state:
        model_id = "bigcode/starcoder2-15b"
        # Load the model in 4-bit so it fits on a single GPU.
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            quantization_config=quantization_config,
            device_map="auto",
        )
        tokenizer.add_eos_token = True
        tokenizer.pad_token_id = 0
        tokenizer.padding_side = "left"

        text_generation_pipeline = pipeline(
            model=model,
            tokenizer=tokenizer,
            task="text-generation",
            temperature=0.7,
            repetition_penalty=1.1,
            return_full_text=True,
            max_new_tokens=300,
        )
        # Wrap the transformers pipeline so LangChain can use it as an LLM.
        st.session_state.llm = HuggingFacePipeline(pipeline=text_generation_pipeline)
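# Minimal sanity-check sketch for the wrapped LLM (run inside the Streamlit app
# after init_llm_pipeline(); the prompt is illustrative, not part of the app):
#
#     init_llm_pipeline()
#     st.write(st.session_state.llm("def fibonacci(n):"))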
def get_text(docs):
    # st.file_uploader with accept_multiple_files=True returns a list of
    # UploadedFile objects; concatenate their decoded contents into one string.
    return "\n\n".join(doc.getvalue().decode("utf-8") for doc in docs)
def get_vectorstore(text):
    # Split the raw source code along Python-aware boundaries.
    python_splitter = RecursiveCharacterTextSplitter.from_language(
        language=Language.PYTHON, chunk_size=2000, chunk_overlap=200
    )
    documents = python_splitter.create_documents([text])
    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    db = FAISS.from_documents(documents, embeddings)
    retriever = db.as_retriever(
        search_type="mmr",  # Also test "similarity"
        search_kwargs={"k": 8},
    )
    return retriever
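# Optional debugging sketch: query the retriever returned by get_vectorstore
# directly to inspect which chunks come back (the query string is illustrative):
#
#     docs = retriever.get_relevant_documents("Where is the FAISS index built?")
#     for d in docs:
#         st.write(d.page_content[:200])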
def get_conversation(retriever):
    memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)
    conversation_chain = ConversationalRetrievalChain.from_llm(
        llm=st.session_state.llm,
        retriever=retriever,
        memory=memory
    )
    return conversation_chain
def handle_user_input(question):
    response = st.session_state.conversation({'question': question})
    st.session_state.chat_history = response['chat_history']
    # The buffer memory alternates user and assistant messages.
    for i, message in enumerate(st.session_state.chat_history):
        if i % 2 == 0:
            with st.chat_message("user"):
                st.write(message.content)
        else:
            with st.chat_message("assistant"):
                st.write(message.content)
def main():
    #load_dotenv()
    st.set_page_config(page_title="Coding Assistant", page_icon=":books:")
    init()
    st.header(":books: Coding Assistant")

    user_input = st.chat_input("Ask your question here")
    if user_input:
        if st.session_state.conversation is None:
            st.warning("Please upload and index your code files first.")
        else:
            with st.spinner("Processing request ..."):
                handle_user_input(user_input)

    with st.sidebar:
        st.subheader("Code Upload")
        upload_docs = st.file_uploader("Upload documents here", accept_multiple_files=True)
        if st.button("Upload"):
            with st.spinner("Analyzing documents ..."):
                init_llm_pipeline()
                raw_text = get_text(upload_docs)
                retriever = get_vectorstore(raw_text)
                st.session_state.conversation = get_conversation(retriever)

if __name__ == "__main__":
    main()
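# Deployment sketch (assumptions about the environment, not part of the original
# file): a local run would look roughly like
#
#     pip install streamlit torch transformers accelerate bitsandbytes \
#         "langchain<0.2" langchain-text-splitters sentence-transformers faiss-cpu
#     streamlit run app.py
#
# HuggingFaceEmbeddings needs sentence-transformers, and the FAISS vector store
# needs the faiss-cpu (or faiss-gpu) package.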