# starcoder2-test / app.py
import streamlit as st
from langchain_text_splitters import Language, RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
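# Note: gpt_model and embedding_model below are defined but never referenced;
# the app only uses default_model_id and the embedding model named in get_vectorstore.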
gpt_model = 'gpt-4-1106-preview'
embedding_model = 'text-embedding-3-small'
default_model_id = "bigcode/starcoder2-7b"
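# Streamlit re-runs the whole script on every interaction, so the conversation
# chain and the chat history are kept in st.session_state across reruns.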
def init():
    if "conversation" not in st.session_state:
        st.session_state.conversation = None
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = None
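# Load the tokenizer and causal LM for the given Hugging Face model ID, wrap them
# in a text-generation pipeline, and cache the resulting LangChain LLM in session state
# so the model is only loaded once.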
def init_llm_pipeline(model_id):
    if "llm" not in st.session_state:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
        )
        tokenizer.add_eos_token = True
        tokenizer.pad_token_id = 0
        tokenizer.padding_side = "left"
        text_generation_pipeline = pipeline(
            model=model,
            tokenizer=tokenizer,
            task="text-generation",
            do_sample=True,  # temperature only takes effect when sampling is enabled
            temperature=0.2,
            repetition_penalty=1.1,
            return_full_text=True,
            max_new_tokens=300,
        )
        st.session_state.llm = HuggingFacePipeline(pipeline=text_generation_pipeline)
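# Concatenate the uploaded files into one raw source string.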
def get_text(docs):
    # st.file_uploader(..., accept_multiple_files=True) returns a list of files,
    # so decode and join all of them instead of calling getvalue() on the list.
    return "\n".join(doc.getvalue().decode("utf-8") for doc in docs)
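# Split the source with a Python-aware splitter, embed the chunks, and index them
# in FAISS; returns an MMR retriever over the indexed chunks.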
def get_vectorstore(text):
    python_splitter = RecursiveCharacterTextSplitter.from_language(
        language=Language.PYTHON, chunk_size=2000, chunk_overlap=200
    )
    # The input is raw source text, so build Document chunks from it
    # (split_documents would expect Document objects, not a string).
    texts = python_splitter.create_documents([text])
    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    db = FAISS.from_documents(texts, embeddings)
    retriever = db.as_retriever(
        search_type="mmr",  # also worth testing "similarity"
        search_kwargs={"k": 8},
    )
    return retriever
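# Build a ConversationalRetrievalChain that combines the cached LLM, the FAISS
# retriever, and a buffer memory holding the chat history.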
def get_conversation(retriever):
    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
    conversation_chain = ConversationalRetrievalChain.from_llm(
        llm=st.session_state.llm,
        retriever=retriever,
        memory=memory,
    )
    return conversation_chain
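# Send the question through the chain and re-render the full chat history,
# alternating between user and assistant messages.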
def handle_user_input(question):
    # Guard against questions asked before any documents have been uploaded.
    if st.session_state.conversation is None:
        st.warning("Please upload your code files first.")
        return
    response = st.session_state.conversation({"question": question})
    st.session_state.chat_history = response["chat_history"]
    for i, message in enumerate(st.session_state.chat_history):
        if i % 2 == 0:
            with st.chat_message("user"):
                st.write(message.content)
        else:
            with st.chat_message("assistant"):
                st.write(message.content)
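# Streamlit entry point: page setup, chat input, and a sidebar for choosing the
# model ID and uploading the code files to index.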
def main():
    # load_dotenv()
    init()
    st.set_page_config(page_title="Coding Assistant", page_icon=":books:")
    st.header(":books: Coding Assistant")
    user_input = st.chat_input("Ask your question here")
    if user_input:
        with st.spinner("Running query ..."):
            handle_user_input(user_input)
    with st.sidebar:
        st.subheader("Model selector")
        model_id = st.text_input("Model ID on Hugging Face", default_model_id)
        st.subheader("Code upload")
        upload_docs = st.file_uploader("Upload documents here", accept_multiple_files=True)
        if st.button("Upload"):
            with st.spinner("Analyzing documents ..."):
                init_llm_pipeline(model_id)
                raw_text = get_text(upload_docs)
                vectorstore = get_vectorstore(raw_text)
                st.session_state.conversation = get_conversation(vectorstore)


if __name__ == "__main__":
    main()