import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Sharded Mistral-7B-Instruct checkpoint used as the default chat model.
model_name = "bn22/Mistral-7B-Instruct-v0.1-sharded"

def load_quantized_model(model_name: str):
    """
    Load a causal LM with 4-bit NF4 quantization via bitsandbytes.

    :param model_name: Name or path of the model to be loaded.
    :return: Loaded quantized model.
    """
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        quantization_config=bnb_config,
    )
    return model
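

# A minimal sketch (not part of the original script; load_quantized_llm is a hypothetical
# helper): the LangChain constructors used further down (create_history_aware_retriever,
# create_stuff_documents_chain) expect a LangChain-compatible LLM rather than a raw
# transformers model, so the quantized model would typically be exposed through a
# text-generation pipeline first.
from transformers import pipeline
from langchain_community.llms import HuggingFacePipeline


def load_quantized_llm(model_name: str):
    """Wrap the 4-bit quantized model in a LangChain-compatible LLM (assumed settings)."""
    model = load_quantized_model(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    gen_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,  # generation length is an assumption; tune as needed
    )
    return HuggingFacePipeline(pipeline=gen_pipeline)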


import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline

from langchain_core.messages import AIMessage, HumanMessage
from langchain_community.document_loaders import WebBaseLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.embeddings import HuggingFaceBgeEmbeddings

from dotenv import load_dotenv
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain


load_dotenv()


def get_vectorstore_from_url(url):
    # Download the page and split it into chunks for embedding.
    loader = WebBaseLoader(url)
    document = loader.load()

    text_splitter = RecursiveCharacterTextSplitter()
    document_chunks = text_splitter.split_documents(document)

    '''
    Chroma
    A persistent Chroma vector store containing the embeddings of the text chunks.
    '''
    model = "BAAI/bge-base-en-v1.5"
    encode_kwargs = {
        "normalize_embeddings": True
    }
    embeddings = HuggingFaceBgeEmbeddings(
        model_name=model, encode_kwargs=encode_kwargs, model_kwargs={"device": "cpu"}
    )

    vector_store = Chroma.from_documents(
        document_chunks, embeddings, persist_directory="./chroma_db"
    )

    # Quick sanity check that similarity search returns results.
    print("-----")
    print(vector_store.similarity_search("What is ALiBi?"))
    print("-----")

    return vector_store
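

# A minimal sketch (not part of the original script; load_persisted_vectorstore is a
# hypothetical helper): because the store is persisted to ./chroma_db, a later run can
# reopen it instead of re-scraping and re-embedding the page, provided the same
# embeddings object is supplied.
def load_persisted_vectorstore(embeddings):
    return Chroma(persist_directory="./chroma_db", embedding_function=embeddings)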


def get_context_retriever_chain(vector_store):
    # Quantized LLM used to turn the running conversation into a standalone search query.
    # Note: the chain constructors below expect a LangChain-compatible LLM; see the
    # pipeline wrapper sketch after load_quantized_model above.
    model_name = "anakin87/zephyr-7b-alpha-sharded"
    llm = load_quantized_model(model_name)

    retriever = vector_store.as_retriever()

    prompt = ChatPromptTemplate.from_messages([
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "{input}"),
        ("user", "Given the above conversation, generate a search query to look up in order to get information relevant to the conversation"),
    ])

    retriever_chain = create_history_aware_retriever(llm, retriever, prompt)

    return retriever_chain


def get_conversational_rag_chain(retriever_chain):
    # Falls back to the module-level model_name (the sharded Mistral-7B-Instruct checkpoint).
    llm = load_quantized_model(model_name)

    prompt = ChatPromptTemplate.from_messages([
        ("system", "Answer the user's questions based on the below context:\n\n{context}"),
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "{input}"),
    ])

    stuff_documents_chain = create_stuff_documents_chain(llm, prompt)

    return create_retrieval_chain(retriever_chain, stuff_documents_chain)


def get_response(user_input):
    # Earlier Streamlit-based variant: assumes streamlit is imported as st and that the
    # vector store and chat history live in st.session_state. Superseded by the Gradio
    # version of get_response defined below.
    retriever_chain = get_context_retriever_chain(st.session_state.vector_store)
    conversation_rag_chain = get_conversational_rag_chain(retriever_chain)

    response = conversation_rag_chain.invoke({
        "chat_history": st.session_state.chat_history,
        "input": user_input
    })

    return response['answer']


import gradio as gr

chat_history = []


def get_response(user_input):
    # Build the vector store and RAG chain for the target page, then answer the query.
    # Note: this rebuilds the store and reloads the quantized models on every call.
    vs = get_vectorstore_from_url(
        "https://www.bofrost.de/shop/laenderkueche_5573/italienische-kueche_5576/linguine-mit-feinen-pilzen.html?position=1&clicked="
    )
    chat_history = []
    retriever_chain = get_context_retriever_chain(vs)
    conversation_rag_chain = get_conversational_rag_chain(retriever_chain)

    response = conversation_rag_chain.invoke({
        "chat_history": chat_history,
        "input": user_input
    })

    return response['answer']
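

# A minimal sketch (not part of the original script; chat_turn is a hypothetical helper):
# AIMessage and HumanMessage are imported above but never used, and get_response resets
# chat_history to an empty list on each call, so follow-up questions lose their context.
# One way to make the history-aware retriever useful is to thread an accumulated history
# through an already-built chain and append each exchange afterwards:
def chat_turn(conversation_rag_chain, history, user_input: str) -> str:
    response = conversation_rag_chain.invoke({
        "chat_history": history,
        "input": user_input,
    })
    history.append(HumanMessage(content=user_input))
    history.append(AIMessage(content=response["answer"]))
    return response["answer"]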


def simple(text: str):
    # Placeholder echo function; not wired into the Gradio interface below.
    return text + " hhhmmm "


app = gr.Interface(
    fn=get_response,
    inputs=["text"],
    outputs="text",
    title="Chat with Websites",
    description="Type your message and chat with websites.",
)

app.launch(debug=True, share=True)