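"""RAG chatbot over bofrost.de pages.

Scrapes product and FAQ pages, indexes them in a persistent Chroma store with
BGE embeddings, and answers questions in German through a Gradio chat UI,
using LangChain retrieval chains on top of a Hugging Face hosted LLM.
"""
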
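# Assumed dependency set (untested, approximate): langchain, langchain-core,
# langchain-community, chromadb, sentence-transformers, transformers, torch,
# gradio, python-dotenv.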
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

from langchain_community.llms import HuggingFaceHub
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

from dotenv import load_dotenv
import gradio as gr

load_dotenv()  # expects HUGGINGFACEHUB_API_TOKEN in the environment or a .env file

model_name = "bn22/Mistral-7B-Instruct-v0.1-sharded"

def load_model(model_name: str):
    """
    Load the LLM used by the chains.

    :param model_name: Name or path of the model to be loaded.
    :return: Loaded model. The active branch returns a hosted Hugging Face
        Hub endpoint; the commented-out branch loads `model_name` locally as
        a 4-bit quantized model instead.
    """
    # Hosted inference endpoint. Note that this ignores the `model_name`
    # argument and always queries the Mixtral instruct endpoint.
    model = HuggingFaceHub(
        repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
        model_kwargs={
            "max_length": 1048,
            "temperature": 0.2,
            "max_new_tokens": 256,
            "top_p": 0.95,
            "repetition_penalty": 1.0,
        },
    )

    # Alternative: local 4-bit (NF4) quantization via bitsandbytes.
    # bnb_config = BitsAndBytesConfig(
    #     load_in_4bit=True,
    #     bnb_4bit_use_double_quant=True,
    #     bnb_4bit_quant_type="nf4",
    #     bnb_4bit_compute_dtype=torch.bfloat16,
    # )
    # model = AutoModelForCausalLM.from_pretrained(
    #     model_name,
    #     torch_dtype=torch.bfloat16,
    #     quantization_config=bnb_config,
    # )

    return model

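# A minimal smoke test, assuming HUGGINGFACEHUB_API_TOKEN is set (the output
# text will vary):
#
#     llm = load_model(model_name)
#     print(llm.invoke("Sage kurz hallo."))
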

def get_vectorstore_from_url(url):
    """
    Scrape a web page, split it into chunks and embed the chunks into a
    persistent Chroma vector store.

    :param url: URL of the page to ingest.
    :return: A Chroma vector store containing the embeddings of the text
        chunks.
    """
    loader = WebBaseLoader(url)
    document = loader.load()

    text_splitter = RecursiveCharacterTextSplitter()
    document_chunks = text_splitter.split_documents(document)

    # BGE embeddings, computed on CPU; normalized embeddings are recommended
    # for similarity search with this model family.
    model = "BAAI/bge-base-en-v1.5"
    encode_kwargs = {"normalize_embeddings": True}
    embeddings = HuggingFaceBgeEmbeddings(
        model_name=model, encode_kwargs=encode_kwargs, model_kwargs={"device": "cpu"}
    )

    # Repeated calls share the same persist directory, so every ingested page
    # accumulates in one collection.
    vector_store = Chroma.from_documents(
        document_chunks, embeddings, persist_directory="./chroma_db"
    )

    return vector_store

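# Example query against the returned store (results depend on what has been
# ingested so far):
#
#     store = get_vectorstore_from_url("https://www.bofrost.de/faq/")
#     for doc in store.similarity_search("Lieferung", k=2):
#         print(doc.page_content[:200])
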
def get_context_retriever_chain(vector_store):
    """Build a history-aware retriever: the LLM rewrites the latest user
    input plus the chat history into a standalone search query."""
    model_name = "anakin87/zephyr-7b-alpha-sharded"
    llm = load_model(model_name)

    retriever = vector_store.as_retriever()

    prompt = ChatPromptTemplate.from_messages([
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "{input}"),
        ("user", "Given the above conversation, generate a search query to look up in order to get information relevant to the conversation"),
    ])

    retriever_chain = create_history_aware_retriever(llm, retriever, prompt)

    return retriever_chain

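# The history-aware retriever lets the LLM condense the running conversation
# into a standalone query and then searches the vector store with it. A
# minimal sketch, assuming `vs` is a populated vector store:
#
#     retriever_chain = get_context_retriever_chain(vs)
#     docs = retriever_chain.invoke({"chat_history": [], "input": "Wie lange backen?"})
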
def get_conversational_rag_chain(retriever_chain):
    """Build the answering chain: retrieved documents are stuffed into the
    system prompt and the LLM answers in German, strictly from that context."""
    llm = load_model(model_name)  # module-level default model name

    prompt = ChatPromptTemplate.from_messages([
        ("system", "Du bist ein freundlicher Mitarbeiter eines Callcenters und antwortest ausschließlich auf Basis des Contexts. Benutze nur den Inhalt des Contexts. Wenn die Antwort nicht aus dem Context hervorgeht, antworte mit: 'Ich bin mir nicht sicher.' Antworte auf Deutsch. CONTEXT:\n\n{context}"),
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "{input}"),
    ])

    stuff_documents_chain = create_stuff_documents_chain(llm, prompt)

    return create_retrieval_chain(retriever_chain, stuff_documents_chain)

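# The combined chain takes {"chat_history": [...], "input": "..."} and returns
# a dict whose "answer" key holds the generated text. A minimal sketch,
# assuming `vs` is an already populated vector store:
#
#     chain = get_conversational_rag_chain(get_context_retriever_chain(vs))
#     result = chain.invoke({"chat_history": [], "input": "Wie wird das Gericht zubereitet?"})
#     print(result["answer"])
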
def history_to_dialog_format(chat_history: list[list[str]]) -> list[dict]:
    """Flatten Gradio chat history ([[user, assistant], ...] pairs) into a
    flat list of {"role": ..., "content": ...} messages."""
    dialog = []
    for pair in chat_history:
        for idx, message in enumerate(pair):
            role = "user" if idx % 2 == 0 else "assistant"
            dialog.append({
                "role": role,
                "content": message,
            })
    return dialog

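# For example, a history of [["Hallo", "Hi, wie kann ich helfen?"]] becomes
# [{"role": "user", "content": "Hallo"},
#  {"role": "assistant", "content": "Hi, wie kann ich helfen?"}].
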
def get_response(message, history):
    """Answer one chat turn. `history` arrives from gr.ChatInterface as
    [[user_message, assistant_message], ...] pairs."""
    # Convert Gradio's pair format into role/content messages;
    # MessagesPlaceholder accepts these dicts as chat history.
    dialog = history_to_dialog_format(history)

    # Reuse the module-level vector store `vs` instead of re-scraping and
    # re-embedding a page on every single message.
    retriever_chain = get_context_retriever_chain(vs)
    conversation_rag_chain = get_conversational_rag_chain(retriever_chain)

    response = conversation_rag_chain.invoke({
        "chat_history": dialog,
        "input": message,
    })
    return response["answer"]


# Ingest every page the assistant should know about. Each call adds that
# page's chunks to the same persistent Chroma collection; `vs` keeps the
# handle used by `get_response`.
urls = [
    "https://www.bofrost.de/faq/",
    "https://www.bofrost.de/shop/fertige-gerichte_5507/auflaeufe_5509/hack-wirsing-auflauf.html?position=1&clicked=",
    "https://www.bofrost.de/shop/kartoffelprodukte_5539/pommes-frites_5540/mikrowellen-pommes.html?position=7&clicked=",
    "https://www.bofrost.de/shop/kartoffelprodukte_5539/pommes-frites_5540/backofen-knusper-frites-1200-g.html?position=1&clicked=search",
    "https://www.bofrost.de/shop/laenderkueche_5573/asiatische-kueche_5574/chinesische-bratnudeln.html?emcs0=1&emcs1=Produktdetailseite&emcs2=00554&emcs3=01270&clicked=recommendation&position=2",
    "https://www.bofrost.de/shop/fertige-gerichte_5507/pfannengerichte_5508/westfaelisches-gruenkohlgericht.html?emcs0=98&emcs1=Produktdetailseite&emcs2=00170&emcs3=00554&clicked=recommendation&position=1",
]
for url in urls:
    vs = get_vectorstore_from_url(url)

app = gr.ChatInterface(
    fn=get_response,
    title="Chat with Websites",
    description="Schreibe hier deine Frage rein...",
    retry_btn=None,
    undo_btn=None,
    clear_btn=None,
)

# `share=True` exposes the app through a temporary public gradio.live URL;
# `debug=True` keeps the process in the foreground and prints errors.
app.launch(debug=True, share=True)