# advisor/app.py — AstraDB-backed Q&A chatbot (Hugging Face Space)
# NOTE: web-page chrome from the HF "raw" view (username, commit hash,
# "raw history blame", file size) was captured into this file and has been
# replaced by this comment so the module parses.
import gradio as gr
from typing import List, Dict
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from transformers import pipeline
import os
from astrapy.db import AstraDB
from dotenv import load_dotenv
from huggingface_hub import login
# Load environment variables
load_dotenv()
# Login to Hugging Face Hub
login(token=os.getenv("HUGGINGFACE_API_TOKEN"))
class AstraDBChatbot:
def __init__(self):
# Initialize AstraDB connection
self.astra_db = AstraDB(
token=os.getenv("ASTRA_DB_APPLICATION_TOKEN"),
api_endpoint=os.getenv("ASTRA_DB_API_ENDPOINT")
)
# Set your collection
self.collection = self.astra_db.collection(os.getenv("ASTRA_DB_COLLECTION"))
# Initialize the model - using a smaller model suitable for CPU
pipe = pipeline(
"text-generation",
model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
max_new_tokens=512,
temperature=0.7,
top_p=0.95,
repetition_penalty=1.15
)
self.llm = HuggingFacePipeline(pipeline=pipe)
# Create prompt template
self.template = """
IMPORTANT: You are a helpful assistant that provides information based on the retrieved context.
STRICT RULES:
1. Base your response ONLY on the provided context
2. If you cannot find relevant information, respond with: "I apologize, but I cannot find information about that in the database."
3. Do not make assumptions or use external knowledge
4. Be concise and accurate in your responses
5. If quoting from the context, clearly indicate it
Context: {context}
Chat History: {chat_history}
Question: {question}
Answer:"""
self.prompt = ChatPromptTemplate.from_template(self.template)
self.chat_history = ""
def _search_astra(self, query: str) -> List[Dict]:
"""Search AstraDB for relevant documents"""
try:
# Perform vector search in AstraDB
results = self.collection.vector_find(
query,
limit=5 # Adjust the limit based on your needs
)
return results
except Exception as e:
print(f"Error searching AstraDB: {str(e)}")
return []
def chat(self, query: str, history) -> str:
"""Process a query and return a response"""
try:
# Search AstraDB for relevant content
search_results = self._search_astra(query)
if not search_results:
return "I apologize, but I cannot find information about that in the database."
# Extract and combine relevant content from search results
context = "\n\n".join([result.get('content', '') for result in search_results])
# Generate response using LLM
chain = self.prompt | self.llm
result = chain.invoke({
"context": context,
"chat_history": self.chat_history,
"question": query
})
self.chat_history += f"\nUser: {query}\nAI: {result}\n"
return result
except Exception as e:
return f"Error processing query: {str(e)}"
# Initialize the chatbot
chatbot = AstraDBChatbot()
# Create the Gradio interface
iface = gr.ChatInterface(
chatbot.chat,
title="AstraDB-powered Q&A Chatbot",
description="Ask questions and get answers from your AstraDB database.",
examples=["What information do you have about this topic?", "Can you tell me more about specific details?"],
theme=gr.themes.Soft()
)
# Launch the interface
iface.launch()