|
from langchain_community.vectorstores import Qdrant |
|
from langchain_groq import ChatGroq |
|
from langchain_huggingface import HuggingFaceEmbeddings |
|
import os |
|
from dotenv import load_dotenv |
|
from langchain.prompts import ChatPromptTemplate |
|
from langchain.schema.runnable import RunnablePassthrough |
|
from langchain.schema.output_parser import StrOutputParser |
|
from qdrant_client import QdrantClient, models |
|
from langchain_qdrant import Qdrant |
|
import gradio as gr |
|
|
|
|
|
# Load variables from a local .env file into the process environment.
load_dotenv()

# Fail fast with a clear message when the key is missing: assigning None to
# os.environ would otherwise raise an opaque "TypeError: str expected".
_groq_key = os.getenv("GROQ_API")
if not _groq_key:
    raise RuntimeError("GROQ_API environment variable is not set (check your .env file)")
os.environ["GROQ_API_KEY"] = _groq_key
|
|
|
|
|
# Embedding model used for both indexing and querying. NOTE(review):
# bge-large-en-v1.5 emits 1024-dimensional vectors — the Qdrant collection's
# vector size must match; verify against the collection config below.
embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-large-en-v1.5")
|
|
|
|
|
# Client for the hosted Qdrant cluster; credentials come from the environment.
# gRPC is preferred over HTTP for lower-latency vector operations.
client = QdrantClient(

    url=os.getenv("QDRANT_URL"),

    api_key=os.getenv("QDRANT_API_KEY"),

    prefer_grpc=True

)



# Name of the Qdrant collection holding the document embeddings.
collection_name = "mawared"
|
|
|
|
|
# Ensure the target collection exists before wiring up the vector store.
# collection_exists() is used instead of the previous create-and-parse-the-
# exception approach, which matched on the error string and was fragile.
#
# NOTE(fix): BAAI/bge-large-en-v1.5 produces 1024-dimensional embeddings, so
# the collection must be created with size=1024 — the previous size=768 would
# cause every upsert/search against a fresh collection to be rejected with a
# dimension-mismatch error.
if not client.collection_exists(collection_name):
    client.create_collection(
        collection_name=collection_name,
        vectors_config=models.VectorParams(
            size=1024,
            distance=models.Distance.COSINE,
        ),
    )
    print(f"Created new collection: {collection_name}")
else:
    print(f"Collection {collection_name} already exists, continuing...")
|
|
|
|
|
# LangChain vector-store wrapper around the existing Qdrant collection;
# queries are embedded with the same model used at indexing time.
db = Qdrant(

    client=client,

    collection_name=collection_name,

    embeddings=embeddings,

)




# Retriever that returns the top-5 most similar chunks for each query.
retriever = db.as_retriever(

    search_type="similarity",

    search_kwargs={"k": 5}

)
|
|
|
|
|
# Groq-hosted Llama 3.3 70B chat model. Low temperature keeps answers close
# to the retrieved context; no token cap or request timeout is imposed, and
# transient failures are retried up to twice.
llm = ChatGroq(

    model="llama-3.3-70b-versatile",

    temperature=0.1,

    max_tokens=None,

    timeout=None,

    max_retries=2,

)
|
|
|
|
|
template = """ |
|
You are an expert assistant specializing in the LONG COT RAG. Your task is to answer the user's question strictly based on the provided context... |
|
Context: |
|
{context} |
|
|
|
Question: |
|
{question} |
|
|
|
Answer: |
|
""" |
|
|
|
prompt = ChatPromptTemplate.from_template(template) |
|
|
|
|
|
# LCEL pipeline: the dict step runs the retriever on the incoming question
# (filling "context") while passing the question itself through unchanged;
# the result is formatted by the prompt, sent to the LLM, and reduced to a
# plain string by the output parser.
rag_chain = (

    {"context": retriever, "question": RunnablePassthrough()}

    | prompt

    | llm

    | StrOutputParser()

)
|
|
|
|
|
def ask_question_gradio(question: str) -> str:
    """Run the RAG chain on *question* and return the complete answer text.

    The chain is consumed via ``stream()`` so the LLM response arrives in
    chunks; they are joined into one string because this handler returns the
    full answer at once rather than streaming tokens to the Gradio UI.
    """
    # str.join avoids the quadratic cost of repeated string concatenation.
    return "".join(rag_chain.stream(question))
|
|
|
|
|
# Gradio front-end: a single text-in / text-out form wired to the RAG handler.
interface = gr.Interface(
    fn=ask_question_gradio,
    inputs="text",
    outputs="text",
    title="Mawared Expert Assistant",
    description=(
        "Ask questions about the Mawared HR System or any related topic "
        "using Chain-of-Thought (CoT) and RAG principles."
    ),
    theme="compact",
)


# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    interface.launch()
|
|