from dotenv import load_dotenv
from fastapi import FastAPI
import gradio as gr
import os
import firebase_admin
from firebase_admin import db, credentials
import datetime
import uuid
from llama_index.core import StorageContext, load_index_from_storage, VectorStoreIndex, SimpleDirectoryReader, ChatPromptTemplate, Settings
from llama_index.llms.huggingface import HuggingFaceInferenceAPI
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
import uvicorn
# Load environment variables
load_dotenv()
# Authenticate to Firebase
cred = credentials.Certificate("redfernstech-fd8fe-firebase-adminsdk-g9vcn-0537b4efd6.json")
firebase_admin.initialize_app(cred, {"databaseURL": "https://redfernstech-fd8fe-default-rtdb.firebaseio.com/"})
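# The service-account JSON above must be present next to app.py; databaseURL
# points at this project's Realtime Database instance, which is where
# save_chat_message() writes chat logs.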
app = FastAPI()
# Configure the Llama index settings
Settings.llm = HuggingFaceInferenceAPI(
    model_name="meta-llama/Meta-Llama-3-8B-Instruct",
    tokenizer_name="meta-llama/Meta-Llama-3-8B-Instruct",
    context_window=3000,
    token=os.getenv("HF_TOKEN"),
    max_new_tokens=512,
    generate_kwargs={"temperature": 0.1},
)
Settings.embed_model = HuggingFaceEmbedding(
    model_name="BAAI/bge-small-en-v1.5"
)
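# Llama-3-8B-Instruct (via the HF Inference API) generates the answers;
# BAAI/bge-small-en-v1.5 is a compact English embedding model used to embed
# documents and queries for retrieval.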
# Define directories
PERSIST_DIR = "db"
PDF_DIRECTORY = "data"
os.makedirs(PDF_DIRECTORY, exist_ok=True)
os.makedirs(PERSIST_DIR, exist_ok=True)
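# handle_query() loads the index from PERSIST_DIR, so
# data_ingestion_from_directory() must have run at least once (see __main__)
# before the first query arrives.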
# In-memory chat history (module-level, shared by every user of this process)
current_chat_history = []
def data_ingestion_from_directory():
    """Build a vector index from the PDFs in PDF_DIRECTORY and persist it."""
    documents = SimpleDirectoryReader(PDF_DIRECTORY).load_data()
    index = VectorStoreIndex.from_documents(documents)
    index.storage_context.persist(persist_dir=PERSIST_DIR)
def handle_query(query):
    chat_text_qa_msgs = [
        (
            "user",
            """
            You are Clara, the Redfernstech chatbot. Provide accurate, professional, and concise answers based on the data. Respond within 10-15 words only.
            {context_str}
            Question:
            {query_str}
            """,
        )
    ]
    text_qa_template = ChatPromptTemplate.from_messages(chat_text_qa_msgs)

    # Reload the persisted index built by data_ingestion_from_directory()
    storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
    index = load_index_from_storage(storage_context)

    # {context_str} in the template is filled by LlamaIndex with the retrieved
    # document context. as_query_engine() has no context_str keyword, so past
    # turns are folded into the query string instead.
    history_str = ""
    for past_query, past_response in reversed(current_chat_history):
        if past_query.strip():
            history_str += f"User asked: '{past_query}'\nBot answered: '{past_response}'\n"

    query_engine = index.as_query_engine(text_qa_template=text_qa_template)
    answer = query_engine.query(f"{history_str}{query}" if history_str else query)
    response = answer.response if hasattr(answer, "response") else "Sorry, I couldn't find an answer."
    current_chat_history.append((query, response))
    return response
def save_chat_message(username, email, session_id, message_data):
    # Firebase RTDB keys may not contain "."; replace dots in the email first
    safe_email = email.replace(".", ",")
    ref = db.reference(f"/chat_history/{username}/{safe_email}/{session_id}")
    ref.push().set(message_data)
def chat_interface(message, history, request: gr.Request):
    try:
        # Fall back to placeholders when the query params are absent
        username = request.query_params.get("username", "guest")
        email = request.query_params.get("email", "unknown")
        # A fresh UUID per message: every exchange is stored under its own
        # session node in Firebase
        session_id = str(uuid.uuid4())
        response = handle_query(message)
        message_data = {
            "sender": request.client.host,
            "message": message,
            "response": response,
            "timestamp": datetime.datetime.now().isoformat(),
        }
        save_chat_message(username, email, session_id, message_data)
        return response
    except Exception as e:
        # Surface the error text in the chat so failures are visible
        return str(e)
css = '''
.circle-logo {
display: inline-block;
width: 40px;
height: 40px;
border-radius: 50%;
overflow: hidden;
margin-right: 10px;
vertical-align: middle;
}
.circle-logo img {
width: 100%;
height: 100%;
object-fit: cover;
}
.response-with-logo {
display: flex;
align-items: center;
margin-bottom: 10px;
}
footer {
display: none !important;
background-color: #F8D7DA;
}
.svelte-1ed2p3z p {
font-size: 24px;
font-weight: bold;
line-height: 1.2;
color: #111;
margin: 20px 0;
}
label.svelte-1b6s6s {display: none}
div.svelte-rk35yg {display: none;}
div.progress-text.svelte-z7cif2.meta-text {display: none;}
'''
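# The .svelte-* selectors above target Gradio's generated (hashed) class
# names; they are specific to the installed Gradio version and may break
# after an upgrade.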
# Build the Gradio UI once and mount it on the FastAPI app. Launching a
# Gradio server inside a request handler would start a fresh server on every
# request; mounting serves the chat UI at /chat, and the username/email query
# params are read per message in chat_interface() via gr.Request.
demo = gr.ChatInterface(
    fn=chat_interface,
    css=css,
    description="Clara",
    clear_btn=None,
    undo_btn=None,
    retry_btn=None,
)
app = gr.mount_gradio_app(app, demo, path="/chat")
if __name__ == "__main__":
    # Build and persist the index before serving queries
    data_ingestion_from_directory()
    # Run uvicorn in the main thread: a daemon thread would be killed as soon
    # as the main thread exited, so the server would never stay up.
    uvicorn.run(app, host="0.0.0.0", port=8000)
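# Example (hypothetical values): with the server running, the chat UI is at
#   http://localhost:8000/chat?username=alice&email=alice@example.com
# and each message is logged in Firebase under
#   /chat_history/alice/alice@example,com/<session-id>/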