# Page header captured with this file: uploaded by Sarath0x8f
# ("Upload 5 files", commit 9dc524d, verified), file size 4.9 kB.
import gradio as gr
import pymongo
import certifi
from llama_index.core import VectorStoreIndex
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.groq import Groq
from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch
from llama_index.core.prompts import PromptTemplate
from dotenv import load_dotenv
import os
import base64
import markdown as md
# Load environment variables (python-dotenv) before the lookups below.
load_dotenv()

# --- MongoDB Config ---
# The Atlas connection string carries credentials, so it is read from the
# ATLAS_CONNECTION_STRING environment variable and never written in this file.
ATLAS_CONNECTION_STRING = os.environ.get("ATLAS_CONNECTION_STRING")
DB_NAME = "RAG"
COLLECTION_NAME = "ramayana"
VECTOR_INDEX_NAME = "ramayana_vector_index"

# --- Embedding Model ---
# Shared embedding model handed to the vector index when it is built.
embed_model = HuggingFaceEmbedding(
    model_name="intfloat/multilingual-e5-base",
)
# --- Prompt Template ---
# Question-answering prompt; it is passed to the query engine as
# `text_qa_template`. The text contains two placeholders, {context_str} and
# {query_str} -- presumably filled with the retrieved passages and the user's
# question respectively (confirm against the LlamaIndex version in use).
# The "Answer:" bullets ask for a three-part reply: intro, shloka(s) with
# explanation, and an overview.
ramayana_qa_template = PromptTemplate(
"""You are an expert on the Valmiki Ramayana and a guide who always inspires people with the great Itihasa like the Ramayana.
Below is text from the epic, including shlokas and their explanations:
---------------------
{context_str}
---------------------
Using only this information, answer the following query.
Query: {query_str}
Answer:
- Intro or general description to ```Query```
- Related shloka/shlokas followed by its explanation
- Overview of ```Query```
"""
)
# --- Connect to MongoDB once at startup ---
def get_vector_index_once():
    """Connect to MongoDB Atlas and return a vector index over the collection.

    Opens a TLS connection using certifi's CA bundle, verifies it with a
    ``server_info()`` round trip, then wraps the configured database,
    collection and vector index in a ``VectorStoreIndex`` that uses the
    module-level ``embed_model``.

    Returns:
        The ``VectorStoreIndex`` built from the Atlas vector store.

    Raises:
        RuntimeError: If ``ATLAS_CONNECTION_STRING`` is unset or empty.
        Exception: Any error raised by the connectivity check is re-raised
            unchanged after the client has been closed.
    """
    if not ATLAS_CONNECTION_STRING:
        # Fail fast with a clear message: a missing connection string would
        # otherwise only surface as a server-selection timeout after the
        # 30 s limit configured below.
        raise RuntimeError(
            "ATLAS_CONNECTION_STRING is not set; define it in the environment "
            "or in a .env file before starting the app."
        )

    mongo_client = pymongo.MongoClient(
        ATLAS_CONNECTION_STRING,
        tlsCAFile=certifi.where(),
        tlsAllowInvalidCertificates=False,
        connectTimeoutMS=30000,
        serverSelectionTimeoutMS=30000,
    )
    try:
        # Forces a round trip so connection problems show up here, at startup.
        mongo_client.server_info()
    except Exception:
        # Do not leave a half-open client behind when the check fails.
        mongo_client.close()
        raise
    print("✅ Connected to MongoDB Atlas.")

    vector_store = MongoDBAtlasVectorSearch(
        mongo_client,
        db_name=DB_NAME,
        collection_name=COLLECTION_NAME,
        vector_index_name=VECTOR_INDEX_NAME,
    )
    return VectorStoreIndex.from_vector_store(vector_store, embed_model=embed_model)
# Connect once, at import time; chat_with_groq builds its query engine from
# this shared index on every request.
vector_index = get_vector_index_once()
# --- Respond Function (uses API key from state) ---
def chat_with_groq(message, history, groq_key):
    """Answer one chat message using the Ramayana index and a Groq LLM.

    Args:
        message: The user's question.
        history: Earlier chat turns supplied by the chat UI. Unused: every
            question is answered on its own.
        groq_key: Groq API key entered by the user.

    Returns:
        The answer as a string. When the API key or the question is missing
        or blank, a short instruction is returned instead and no LLM or
        database call is made.
    """
    # Pasted keys often carry stray whitespace; None/"" means "not provided".
    api_key = (groq_key or "").strip()
    if not api_key:
        return "Please enter your Groq API key to start chatting."

    question = (message or "").strip()
    if not question:
        return "Please type a question about the Ramayana."

    llm = Groq(model="llama-3.1-8b-instant", api_key=api_key)
    query_engine = vector_index.as_query_engine(
        llm=llm,
        text_qa_template=ramayana_qa_template,
        similarity_top_k=5,
        verbose=True,
    )
    response = query_engine.query(question)
    return str(response)
def encode_image(image_path):
    """Return the contents of the file at *image_path* as a base64 string.

    The file is read in binary mode and the base64 bytes are decoded as
    UTF-8 so the result can be embedded directly in text such as HTML.
    """
    with open(image_path, "rb") as image_file:
        raw_bytes = image_file.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode('utf-8')
# Base64-encode the three footer logos, in the order the footer template
# receives them: GitHub, LinkedIn, website.
github_logo_encoded, linkedin_logo_encoded, website_logo_encoded = map(
    encode_image,
    (
        "Images/github-logo.png",
        "Images/linkedin-logo.png",
        "Images/ai-logo.png",
    ),
)
# --- Gradio UI ---
# Two tabs: "Intro" shows the project description; "GPT" first asks for a Groq
# API key and, once "Start Chat" is pressed, swaps the form for the chat.
# NOTE(review): the nesting below was reconstructed from a copy of the file
# whose indentation was lost -- confirm the layout against the running app.
with gr.Blocks(
    theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Roboto Mono")]),
    css='footer {visibility: hidden}',
) as demo:
    with gr.Tabs():
        with gr.TabItem("Intro"):
            gr.Markdown(md.description)

        with gr.TabItem("GPT"):
            # Key-entry form, visible until the chat is started.
            with gr.Column(visible=True) as accordion_container:
                with gr.Accordion("How to get Groq API KEY", open=False):
                    gr.Markdown(md.groq_api_key)
                groq_key_box = gr.Textbox(
                    label="Enter Groq API Key",
                    type="password",
                    placeholder="Paste your Groq API key here...",
                )
                start_btn = gr.Button("Start Chat")

            # Holds the submitted key; handed to every chat call through
            # `additional_inputs` below.
            groq_state = gr.State(value="")

            # Chat container, initially hidden
            with gr.Column(visible=False) as chatbot_container:
                chatbot = gr.ChatInterface(
                    # chat_with_groq already has the (message, history, key)
                    # signature, so it is passed directly.
                    fn=chat_with_groq,
                    additional_inputs=[groq_state],
                    title="🕉️ RamayanaGPT",
                    # description="Ask questions from the Valmiki Ramayana. Powered by RAG + MongoDB + LlamaIndex.",
                )

            # Show chat and hide inputs
            def save_key_and_show_chat(key):
                """Store the entered key and switch from the form to the chat.

                Returns one value per component in the click handler's
                ``outputs`` list: the key (whitespace-stripped) for the
                state, three "hide" updates for the form widgets and one
                "show" update for the chat container.
                """
                cleaned_key = (key or "").strip()
                return (
                    cleaned_key,
                    gr.update(visible=False),
                    gr.update(visible=False),
                    gr.update(visible=False),
                    gr.update(visible=True),
                )

            start_btn.click(
                fn=save_key_and_show_chat,
                inputs=[groq_key_box],
                outputs=[
                    groq_state,
                    groq_key_box,
                    start_btn,
                    accordion_container,
                    chatbot_container,
                ],
            )

    # Footer HTML with the three base64-encoded logos substituted in.
    gr.HTML(md.footer.format(github_logo_encoded, linkedin_logo_encoded, website_logo_encoded))
# Launch the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()