# fern / app.py

import os
import random
import datetime

from dotenv import load_dotenv
import gradio as gr
from gradio_client import Client
from llama_index.core import (
    StorageContext,
    load_index_from_storage,
    VectorStoreIndex,
    SimpleDirectoryReader,
    ChatPromptTemplate,
    Settings,
)
from llama_index.llms.huggingface import HuggingFaceInferenceAPI
from llama_index.embeddings.huggingface import HuggingFaceEmbedding

# Initialize Gradio Client
client = Client("srinukethanaboina/SRUNU")
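# All end-user chat traffic is proxied to the srinukethanaboina/SRUNU Space
# through its /predict endpoint (see predict() below).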

def select_random_name():
    # NOTE: currently unused; kept as a helper for picking a bot persona name.
    names = ['Clara', 'Lily']
    return random.choice(names)

# Load environment variables
load_dotenv()

# Configure the LlamaIndex settings
Settings.llm = HuggingFaceInferenceAPI(
    model_name="meta-llama/Meta-Llama-3-8B-Instruct",
    tokenizer_name="meta-llama/Meta-Llama-3-8B-Instruct",
    context_window=3000,
    token=os.getenv("HF_TOKEN"),  # requires HF_TOKEN in the environment or a .env file
    max_new_tokens=512,
    generate_kwargs={"temperature": 0.1},
)
Settings.embed_model = HuggingFaceEmbedding(
    model_name="BAAI/bge-small-en-v1.5"
)
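# bge-small-en-v1.5 is a compact embedding model (384-dimensional vectors) used
# both when building the index and when embedding queries at retrieval time.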

# Define the directories for persistent storage and data
PERSIST_DIR = "db"
PDF_DIRECTORY = 'data'  # Directory containing the PDFs

# Ensure directories exist
os.makedirs(PDF_DIRECTORY, exist_ok=True)
os.makedirs(PERSIST_DIR, exist_ok=True)

# In-memory store of (query, response) pairs for the current chat conversation
current_chat_history = []

def data_ingestion_from_directory():
    # Read every file in PDF_DIRECTORY and build a fresh vector index
    documents = SimpleDirectoryReader(PDF_DIRECTORY).load_data()
    index = VectorStoreIndex.from_documents(documents)
    # Persist the index to disk so handle_query() can reload it
    index.storage_context.persist(persist_dir=PERSIST_DIR)
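
# A minimal sketch (not wired into the app above): reuse the persisted index
# instead of re-embedding every PDF on each startup. Assumes PERSIST_DIR only
# contains files once data_ingestion_from_directory() has run at least once.
def load_or_build_index():
    if os.listdir(PERSIST_DIR):
        # Reload the previously persisted index
        storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
        return load_index_from_storage(storage_context)
    # First run: ingest, persist, and return a fresh index
    documents = SimpleDirectoryReader(PDF_DIRECTORY).load_data()
    index = VectorStoreIndex.from_documents(documents)
    index.storage_context.persist(persist_dir=PERSIST_DIR)
    return index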

def handle_query(query):
    chat_text_qa_msgs = [
        (
            "user",
            """
            You are Clara, the Redfernstech chatbot. Your goal is to provide accurate, professional, and helpful answers to user queries based on the company's data. Always ensure your responses are clear and concise, within 10-15 words.
            {context_str}
            Question:
            {query_str}
            """
        )
    ]
    text_qa_template = ChatPromptTemplate.from_messages(chat_text_qa_msgs)
    # Load the persisted index from storage
    storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
    index = load_index_from_storage(storage_context)
    # Fold recent chat history into the query text. Note: as_query_engine() takes
    # no context_str argument; the {context_str} slot in the template is filled
    # with retrieved document text by the query engine itself.
    history_str = ""
    for past_query, past_response in reversed(current_chat_history):
        if past_query.strip():
            history_str += f"User asked: '{past_query}'\nBot answered: '{past_response}'\n"
    query_engine = index.as_query_engine(text_qa_template=text_qa_template)
    query_with_history = f"{history_str}Current question: {query}" if history_str else query
    answer = query_engine.query(query_with_history)
    # Extract the response text across the possible return shapes
    if hasattr(answer, 'response'):
        response = answer.response
    elif isinstance(answer, dict) and 'response' in answer:
        response = answer['response']
    else:
        response = "Sorry, I couldn't find an answer."
    # Update the current chat history
    current_chat_history.append((query, response))
    return response
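
# Illustrative only (hypothetical query): handle_query("What services does
# Redfernstech provide?") returns a short answer string grounded in the indexed PDFs.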

# Build the index once at startup. Note this re-ingests the PDFs and overwrites
# the persisted index in PERSIST_DIR on every launch.
print("Processing PDF ingestion from directory:", PDF_DIRECTORY)
data_ingestion_from_directory()

def predict(message, req):
    logo_html = '''
    <div class="circle-logo">
        <img src="https://rb.gy/8r06eg" alt="FernAi">
    </div>
    '''
    # Use the gradio_client API to process the chat message and IP address
    response = client.predict(
        ip_address=req.client.host,  # Extract the caller's IP from the gr.Request
        chat_history=message,
        api_name="/predict"
    )
    response_with_logo = f'<div class="response-with-logo">{logo_html}<div class="response-text">{response}</div></div>'
    return response_with_logo
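
# predict() returns raw HTML; gr.ChatInterface renders it in the chat window,
# styled by the .circle-logo / .response-with-logo rules in the css string below.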

# Gradio chat interface function: receives the user message, the running
# history, and the gr.Request for the current call.
def chat_interface(message, history, request: gr.Request):
    try:
        # Process the user message and generate a response
        response = predict(message, request)
        # Capture the message data (currently unused; see the logging sketch below)
        message_data = {
            "sender": "user",
            "message": message,
            "response": response,
            "timestamp": datetime.datetime.now().isoformat()
        }
        # Return the bot response
        return response
    except Exception as e:
        return str(e)
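
# A minimal sketch (not part of the original app): persist each message_data
# record as one JSON line for later auditing. The log_message() helper and the
# chat_log.jsonl filename are illustrative assumptions; nothing calls this yet.
import json

def log_message(record, path="chat_log.jsonl"):
    # Append one JSON object per line; the file is created on first write
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(record) + "\n")
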
# Custom CSS for styling
css = '''
.circle-logo {
display: inline-block;
width: 40px;
height: 40px;
border-radius: 50%;
overflow: hidden;
margin-right: 10px;
vertical-align: middle;
}
.circle-logo img {
width: 100%;
height: 100%;
object-fit: cover;
}
.response-with-logo {
display: flex;
align-items: center;
margin-bottom: 10px;
}
footer {
display: none !important;
background-color: #F8D7DA;
}
.svelte-1ed2p3z p {
font-size: 24px;
font-weight: bold;
line-height: 1.2;
color: #111;
margin: 20px 0;
}
label.svelte-1b6s6s {display: none}
div.svelte-rk35yg {display: none;}
div.progress-text.svelte-z7cif2.meta-text {display: none;}
'''
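
# NOTE: the .svelte-* selectors above target Gradio's auto-generated internal
# class names; they are specific to the Gradio version in use and may stop
# matching after an upgrade.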

# Launch the Gradio interface with verbose error reporting
gr.ChatInterface(
    chat_interface,
    css=css,
    description="Clara",
    clear_btn=None, undo_btn=None, retry_btn=None,
).launch(show_error=True)