import gradio as gr
from typing import List, Dict
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from transformers import pipeline
import chromadb
from chromadb.utils import embedding_functions
class ChromaDBChatbot:
    def __init__(self):
        # Initialize an in-memory ChromaDB client
        self.chroma_client = chromadb.Client()
        # Embedding function backed by a small sentence-transformers model
        self.embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name="all-MiniLM-L6-v2"
        )
        # Create the collection, or reuse it if it already exists
        self.collection = self.chroma_client.get_or_create_collection(
            name="text_collection",
            embedding_function=self.embedding_function
        )
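        # Note: chromadb.Client() is purely in-memory, so the index is rebuilt
        # on every restart; chromadb.PersistentClient(path=...) could be
        # swapped in if the embeddings should survive restarts.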
        # Initialize the model - using a smaller model suitable for CPU
        pipe = pipeline(
            "text-generation",
            model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            max_new_tokens=512,
            do_sample=True,  # needed for temperature/top_p to take effect
            temperature=0.7,
            top_p=0.95,
            repetition_penalty=1.15
        )
        self.llm = HuggingFacePipeline(pipeline=pipe)
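        # HuggingFacePipeline adapts the transformers pipeline to LangChain's
        # Runnable interface, so it composes with prompts as `prompt | llm`
        # (see chat() below). The formatted prompt is passed as plain text;
        # TinyLlama's chat template is not applied by this wrapper.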
        # Enhanced prompt templates
        self.templates = {
            "default": """
            IMPORTANT: You are a helpful assistant that provides information based on the retrieved context.
            STRICT RULES:
            1. Base your response ONLY on the provided context
            2. If you cannot find relevant information, respond with: "I apologize, but I cannot find information about that in the database."
            3. Do not make assumptions or use external knowledge
            4. Be concise and accurate in your responses
            5. If quoting from the context, clearly indicate it
            Context: {context}
            Chat History: {chat_history}
            Question: {question}
            Answer:""",
            "summary": """
            Create a concise summary of the following context.
            Context: {context}
            Key Requirements:
            1. Highlight the main points
            2. Keep it brief and clear
            3. Use bullet points if appropriate
            4. Include only information from the context
            Summary:""",
            "technical": """
            Provide a technical explanation based on the context.
            Context: {context}
            Question: {question}
            Guidelines:
            1. Focus on technical details
            2. Explain complex concepts clearly
            3. Use appropriate technical terminology
            4. Provide examples if present in the context
            Technical Explanation:"""
        }
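        # chat() picks one of these templates by keyword matching on the
        # query; extra variables supplied at invoke time (e.g. "question"
        # for the summary template) are simply ignored by the prompt.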
        self.chat_history = ""
        self.loaded = False
    def load_data(self, file_path: str) -> bool:
        """Load data into ChromaDB."""
        if self.loaded:
            return True
        try:
            # Read the text file
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
            # Split into fixed-size character chunks
            # (512 characters each, with a 50-character overlap)
            chunk_size = 512
            overlap = 50
            chunks = []
            for i in range(0, len(content), chunk_size - overlap):
                chunks.append(content[i:i + chunk_size])
            # Add documents to the collection
            self.collection.add(
                documents=chunks,
                ids=[f"doc_{i}" for i in range(len(chunks))]
            )
            self.loaded = True
            print(f"Loaded {len(chunks)} chunks into ChromaDB")
            return True
        except Exception as e:
            print(f"Error loading data: {str(e)}")
            return False
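    # The fixed-width slicing above can cut words (and sentences) in half.
    # A minimal word-boundary-aware alternative, kept here only as a sketch
    # (not wired into this Space; the chunk_words/overlap values are assumptions):
    #
    #   def split_on_words(content: str, chunk_words: int = 90, overlap: int = 10):
    #       words = content.split()
    #       step = chunk_words - overlap
    #       return [" ".join(words[i:i + chunk_words])
    #               for i in range(0, len(words), step)]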
    def _search_chroma(self, query: str) -> List[Dict]:
        """Search ChromaDB for documents relevant to the query."""
        try:
            results = self.collection.query(
                query_texts=[query],
                n_results=5
            )
            return [{"content": doc} for doc in results['documents'][0]]
        except Exception as e:
            print(f"Error searching ChromaDB: {str(e)}")
            return []
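    # query() embeds the query text with the collection's embedding function
    # (the same all-MiniLM-L6-v2 model used at insert time) and returns the
    # n_results nearest chunks, ordered from most to least similar.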
    def chat(self, query: str, history) -> str:
        """Process a query and return a response."""
        try:
            if not self.loaded:
                self.load_data('a2023-45.txt')
            # Determine template type based on the query
            template_type = "default"
            if any(word in query.lower() for word in ["summarize", "summary"]):
                template_type = "summary"
            elif any(word in query.lower() for word in ["technical", "explain", "how does"]):
                template_type = "technical"
            # Search ChromaDB for relevant content
            search_results = self._search_chroma(query)
            if not search_results:
                return "I apologize, but I cannot find information about that in the database."
            # Combine the retrieved chunks into a single context string
            context = "\n\n".join(result['content'] for result in search_results)
            # Create prompt with the selected template
            prompt = ChatPromptTemplate.from_template(self.templates[template_type])
            # Generate a response using the LLM
            chain = prompt | self.llm
            result = chain.invoke({
                "context": context,
                "chat_history": self.chat_history,
                "question": query
            })
            # Update chat history
            self.chat_history += f"\nUser: {query}\nAI: {result}\n"
            return result
        except Exception as e:
            return f"Error processing query: {str(e)}"
# Initialize the chatbot
chatbot = ChromaDBChatbot()

# Create the Gradio interface
demo = gr.Interface(
    fn=chatbot.chat,
    inputs=[
        gr.Textbox(
            label="Your Question",
            placeholder="Ask anything about the document...",
            lines=2
        ),
        gr.State([])  # For chat history
    ],
    outputs=gr.Textbox(label="Answer", lines=10),
    title="ChromaDB-powered Document Q&A",
    description="""
    Ask questions about your document:
    - For summaries, include words like 'summarize' or 'summary'
    - For technical details, use words like 'technical', 'explain', 'how does'
    - For general questions, just ask normally
    """,
    examples=[
        ["Can you summarize the main points?"],
        ["What are the technical details about this topic?"],
        ["Give me a general overview of the content."],
    ],
    theme=gr.themes.Soft()
)

# Launch the interface
if __name__ == "__main__":
    demo.launch()
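# On Hugging Face Spaces, launch() with no arguments is all that's needed;
# run locally, Gradio serves the app at http://127.0.0.1:7860 by default.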