import os
from typing import List

from dotenv import load_dotenv
import chainlit as cl

from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.chat_models import ChatOllama
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
from langchain_community.vectorstores import Qdrant
from langchain_groq import ChatGroq
from qdrant_client import QdrantClient

# Load API keys and connection settings from a local .env file.
load_dotenv()

groq_api_key = os.getenv("GROQ_API_KEY")
qdrant_url = os.getenv("QDRANT_URL")
qdrant_api_key = os.getenv("QDRANT_API_KEY")
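
# The variables above are read from .env; a minimal example layout
# (placeholder values, not real credentials):
#
#   GROQ_API_KEY=gsk_xxxxxxxxxxxxxxxx
#   QDRANT_URL=https://your-cluster.cloud.qdrant.io:6333
#   QDRANT_API_KEY=xxxxxxxxxxxxxxxx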

custom_prompt_template = """Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know; don't try to make up an answer.

Context: {context}
Question: {question}

Only return the helpful answer below and nothing else.
Helpful answer:
"""


def set_custom_prompt():
    """
    Prompt template for QA retrieval for each vectorstore.
    """
    prompt = PromptTemplate(
        template=custom_prompt_template,
        input_variables=['context', 'question'],
    )
    return prompt


# ChatGroq reads GROQ_API_KEY from the environment; temperature=0 keeps answers deterministic.
chat_model = ChatGroq(temperature=0, model_name="mixtral-8x7b-32768")
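
# The ChatOllama import above allows a local model to be swapped in instead of
# Groq; a sketch, assuming an Ollama server is running with the model pulled:
# chat_model = ChatOllama(model="mistral")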

# Client for the remote Qdrant instance that holds the embedded documents.
client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)


def retrieval_qa_chain(llm, prompt, vectorstore):
    """
    Builds a RetrievalQA chain over the given vector store.
    """
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=vectorstore.as_retriever(search_kwargs={'k': 2}),  # top-2 chunks per query
        return_source_documents=True,
        chain_type_kwargs={'prompt': prompt},
    )
    return qa_chain
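
# With chain_type="stuff", the retrieved chunks are concatenated into the
# prompt's {context} variable and the user's query fills {question}.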


def qa_bot():
    """
    Assembles the QA pipeline: embeddings, vector store, LLM, and prompt.
    """
    embeddings = FastEmbedEmbeddings()
    vectorstore = Qdrant(client=client, embeddings=embeddings, collection_name="rag")
    llm = chat_model
    qa_prompt = set_custom_prompt()
    qa = retrieval_qa_chain(llm, qa_prompt, vectorstore)
    return qa
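
# The chain can also be exercised outside Chainlit; a minimal sketch with a
# hypothetical query (assumes the "rag" collection is already populated):
#
#   res = qa_bot().invoke({"query": "What topics does chapter 1 cover?"})
#   print(res["result"])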


@cl.set_chat_profiles
async def chat_profile():
    return [
        cl.ChatProfile(
            name="Virtual Tutor",
            markdown_description="The underlying LLM model is **Mixtral**.",
            icon="https://www.google.com/url?sa=i&url=https%3A%2F%2Fwww.gptshunter.com%2Fgpt-store%2FNTQyNjE4MGMyMzU1MTcyNjU4&psig=AOvVaw3dz6CEyBeDM9iyj8gcEwNI&ust=1711780055423000&source=images&cd=vfe&opi=89978449&ved=0CBAQjRxqFwoTCNjprerrmIUDFQAAAAAdAAAAABAK",
        ),
    ]


@cl.on_chat_start
async def start():
    """
    Initializes the bot when a new chat starts.

    This asynchronous function creates a new instance of the retrieval QA bot,
    sends a welcome message, and stores the bot instance in the user's session.
    """
    await cl.Avatar(
        name="Tool 1",
        url="https://www.google.com/url?sa=i&url=https%3A%2F%2Fwww.gptshunter.com%2Fgpt-store%2FNTQyNjE4MGMyMzU1MTcyNjU4&psig=AOvVaw3dz6CEyBeDM9iyj8gcEwNI&ust=1711780055423000&source=images&cd=vfe&opi=89978449&ved=0CBAQjRxqFwoTCNjprerrmIUDFQAAAAAdAAAAABAK",
    ).send()

    chain = qa_bot()
    welcome_message = cl.Message(content="Starting the bot...")
    await welcome_message.send()
    welcome_message.content = "Welcome to Virtual Tutor."
    await welcome_message.update()
    # user_session is scoped to the current chat, so each user gets their own chain.
    cl.user_session.set("chain", chain)


@cl.on_message
async def main(message: cl.Message):
    """
    Processes incoming chat messages.

    This asynchronous function retrieves the QA bot instance from the user's session,
    sets up a callback handler for the bot's response, and executes the bot's
    call method with the given message and callback. The bot's answer and source
    documents are then extracted from the response.
    """
    chain = cl.user_session.get("chain")
    cb = cl.AsyncLangchainCallbackHandler()
    cb.answer_reached = True

    res = await chain.acall(message.content, callbacks=[cb])
    answer = res["result"]
    source_documents = res["source_documents"]

    text_elements: List[cl.Text] = []

    if source_documents:
        # Attach each retrieved chunk as a named text element so the UI can show it.
        for source_idx, source_doc in enumerate(source_documents):
            source_name = f"source_{source_idx}"
            text_elements.append(
                cl.Text(content=source_doc.page_content, name=source_name)
            )
        source_names = [text_el.name for text_el in text_elements]

        if source_names:
            answer += f"\nSources: {', '.join(source_names)}"
        else:
            answer += "\nNo sources found"

    await cl.Message(content=answer, elements=text_elements, author="Tool 1").send()
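
# To launch the app (assuming this file is saved as app.py):
#
#   chainlit run app.py -w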