import os

import gradio as gr
from typing import List, Dict
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from transformers import pipeline
from astrapy.db import AstraDB
from dotenv import load_dotenv
from huggingface_hub import login
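# Dependencies (PyPI package names inferred from the imports above):
#   gradio, transformers, langchain-core, langchain-community,
#   astrapy, python-dotenv, huggingface-hub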
# Load environment variables from a local .env file
load_dotenv()

# Login to Hugging Face Hub (token read from the environment)
login(token=os.getenv("HUGGINGFACE_API_TOKEN"))
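# Expected .env contents (variable names taken from the os.getenv calls in this
# file; the values shown are placeholders):
#   HUGGINGFACE_API_TOKEN=hf_...
#   ASTRA_DB_APPLICATION_TOKEN=AstraCS:...
#   ASTRA_DB_API_ENDPOINT=https://<db-id>-<region>.apps.astra.datastax.com
#   ASTRA_DB_COLLECTION=<collection-name>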
class AstraDBChatbot:
    def __init__(self):
        # Initialize the AstraDB connection
        self.astra_db = AstraDB(
            token=os.getenv("ASTRA_DB_APPLICATION_TOKEN"),
            api_endpoint=os.getenv("ASTRA_DB_API_ENDPOINT")
        )

        # Set your collection
        self.collection = self.astra_db.collection(os.getenv("ASTRA_DB_COLLECTION"))

        # Initialize the model - a small model suitable for CPU inference
        pipe = pipeline(
            "text-generation",
            model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            max_new_tokens=512,
            do_sample=True,           # sampling must be on for temperature/top_p to apply
            temperature=0.7,
            top_p=0.95,
            repetition_penalty=1.15,
            return_full_text=False    # return only the completion, not the echoed prompt
        )
        self.llm = HuggingFacePipeline(pipeline=pipe)
        # Create the prompt template
        self.template = """IMPORTANT: You are a helpful assistant that provides information based on the retrieved context.

STRICT RULES:
1. Base your response ONLY on the provided context
2. If you cannot find relevant information, respond with: "I apologize, but I cannot find information about that in the database."
3. Do not make assumptions or use external knowledge
4. Be concise and accurate in your responses
5. If quoting from the context, clearly indicate it

Context: {context}

Chat History: {chat_history}

Question: {question}

Answer:"""
        self.prompt = ChatPromptTemplate.from_template(self.template)
        self.chat_history = ""
    def _search_astra(self, query: str) -> List[Dict]:
        """Search AstraDB for relevant documents."""
        try:
            # Perform a vector search in AstraDB.
            # NOTE: astrapy's vector_find expects an embedding vector, not raw
            # text; passing the query string directly assumes the collection is
            # configured for server-side embedding. If it is not, embed `query`
            # first and pass the resulting vector here.
            results = self.collection.vector_find(
                query,
                limit=5  # adjust the number of retrieved documents as needed
            )
            return results
        except Exception as e:
            print(f"Error searching AstraDB: {str(e)}")
            return []
    def chat(self, query: str, history) -> str:
        """Process a query and return a response (signature required by gr.ChatInterface)."""
        try:
            # Search AstraDB for relevant content
            search_results = self._search_astra(query)
            if not search_results:
                return "I apologize, but I cannot find information about that in the database."

            # Extract and combine the relevant content from the search results
            # (assumes each stored document keeps its text under a 'content' field)
            context = "\n\n".join([result.get('content', '') for result in search_results])

            # Generate a response using the LLM
            chain = self.prompt | self.llm
            result = chain.invoke({
                "context": context,
                "chat_history": self.chat_history,
                "question": query
            })

            # Append this turn to the running chat history
            self.chat_history += f"\nUser: {query}\nAI: {result}\n"
            return result
        except Exception as e:
            return f"Error processing query: {str(e)}"
# Initialize the chatbot
chatbot = AstraDBChatbot()

# Create the Gradio interface
iface = gr.ChatInterface(
    chatbot.chat,
    title="AstraDB-powered Q&A Chatbot",
    description="Ask questions and get answers from your AstraDB database.",
    examples=["What information do you have about this topic?", "Can you tell me more about specific details?"],
    theme=gr.themes.Soft()
)

# Launch the interface
iface.launch()
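# To run locally (assuming the file is saved as app.py):
#   python app.py
# Gradio serves the chat UI at http://127.0.0.1:7860 by default.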