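"""Retrieval-augmented Gradio chatbot.

Indexes a local text file into ChromaDB using sentence-transformer
embeddings, retrieves the most similar chunks for each question, and
answers with TinyLlama-1.1B-Chat through a LangChain prompt pipeline.
"""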
import gradio as gr
from typing import List, Dict
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
import chromadb
from chromadb.utils import embedding_functions
from tqdm import tqdm
import os
from huggingface_hub import login
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Log in to the Hugging Face Hub if a token is available
if os.getenv("HUGGINGFACE_API_TOKEN"):
    login(token=os.getenv("HUGGINGFACE_API_TOKEN"))


class EnhancedChatbot:
    def __init__(self):
        # Initialize ChromaDB (in-memory client)
        self.chroma_client = chromadb.Client()

        # Initialize the embedding model using sentence-transformers
        self.embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name="all-MiniLM-L6-v2"
        )

        # Create a collection that uses cosine similarity
        self.collection = self.chroma_client.create_collection(
            name="text_collection",
            embedding_function=self.embedding_function,
            metadata={"hnsw:space": "cosine"}
        )
        # Initialize the LLM with 8-bit quantization for efficiency
        # (quantization_config replaces the deprecated load_in_8bit kwarg)
        model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=BitsAndBytesConfig(load_in_8bit=True),
            device_map="auto",
            torch_dtype=torch.float16
        )
        pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.95,
            repetition_penalty=1.15,
            do_sample=True,
            return_full_text=False  # return only the completion, not the echoed prompt
        )
        self.llm = HuggingFacePipeline(pipeline=pipe)
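
        # Note: 8-bit loading requires the bitsandbytes package and a CUDA GPU.
        # On CPU-only hardware, drop quantization_config and load the model in
        # full precision instead.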
        # Prompt templates tailored to specific use cases
        self.templates = {
            "default": """
You are a knowledgeable assistant providing accurate information based on the given context.
GUIDELINES:
1. Use ONLY the provided context
2. If the information is not in the context, say "I don't have enough information"
3. Be concise and clear
4. Use markdown formatting for better readability
5. If quoting, use proper citation format
Context: {context}
Chat History: {chat_history}
Question: {question}
Response:""",
            "summary": """
Create a comprehensive summary of the provided context.
Context: {context}
REQUIREMENTS:
1. Structure the summary with clear headings
2. Use bullet points for key information
3. Highlight important concepts
4. Maintain factual accuracy
Summary:""",
            "technical": """
Provide a detailed technical analysis of the context.
Context: {context}
Question: {question}
GUIDELINES:
1. Focus on technical specifications
2. Explain complex concepts clearly
3. Use appropriate technical terminology
4. Include relevant examples from the context
5. Structure the response logically
Technical Analysis:""",
            "comparative": """
Compare and analyze different aspects from the context.
Context: {context}
Question: {question}
APPROACH:
1. Identify key points for comparison
2. Analyze similarities and differences
3. Present balanced viewpoints
4. Use tables or lists for clarity
Comparison:"""
        }
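
        # Note: only the "default" template consumes {chat_history}; LangChain
        # ignores extra input variables at invoke time, so passing chat_history
        # to the other templates is harmless.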
        self.chat_history = []
        self.loaded = False

    def load_data(self, file_path: str, chunk_size: int = 512, overlap: int = 50):
        """Load a text file and index it into ChromaDB, with a progress bar."""
        if self.loaded:
            return True
        try:
            # Read the text file
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()

            # Split into fixed-size character chunks with overlap
            chunks = []
            for i in range(0, len(content), chunk_size - overlap):
                chunks.append(content[i:i + chunk_size])

            # Add the chunks to the collection, with a progress bar
            for i, chunk in tqdm(enumerate(chunks), desc="Loading chunks", total=len(chunks)):
                self.collection.add(
                    documents=[chunk],
                    ids=[f"chunk_{i}"],
                    metadatas=[{"source": file_path, "chunk_id": i}]
                )
            self.loaded = True
            return True
        except Exception as e:
            print(f"Error loading data: {str(e)}")
            return False

    def _search_documents(self, query: str, n_results: int = 5) -> List[Dict]:
        """Retrieve the most relevant chunks for a query."""
        try:
            results = self.collection.query(
                query_texts=[query],
                n_results=n_results,
                include=["documents", "metadatas", "distances"]
            )
            return [
                {
                    "content": doc,
                    "metadata": meta,
                    "similarity": 1 - dist  # convert cosine distance to similarity
                }
                for doc, meta, dist in zip(
                    results['documents'][0],
                    results['metadatas'][0],
                    results['distances'][0]
                )
            ]
        except Exception as e:
            print(f"Search error: {str(e)}")
            return []

    def _select_template(self, query: str) -> str:
        """Pick a prompt template based on keywords in the query."""
        query_lower = query.lower()
        if any(word in query_lower for word in ["summarize", "summary", "overview"]):
            return "summary"
        elif any(word in query_lower for word in ["technical", "explain how", "how does"]):
            return "technical"
        elif any(word in query_lower for word in ["compare", "difference", "versus", "vs"]):
            return "comparative"
        return "default"

    def chat(self, query: str, history) -> str:
        """Process a query and generate a response.

        The `history` argument comes from the gr.State input and is unused;
        chat history is tracked on the instance instead.
        """
        try:
            # Lazily index the corpus on first use
            if not self.loaded:
                if not self.load_data('a2023-45.txt'):
                    return "Error: Failed to load document data."

            # Retrieve relevant chunks
            search_results = self._search_documents(query)
            if not search_results:
                return "I apologize, but I couldn't find relevant information in the database."

            # Assemble the context, annotated with similarity scores
            context_parts = [
                f"[Similarity: {result['similarity']:.2f}]\n{result['content']}"
                for result in search_results
            ]
            context = "\n\n".join(context_parts)

            # Select the appropriate template and build the chain
            template_type = self._select_template(query)
            prompt = ChatPromptTemplate.from_template(self.templates[template_type])
            chain = prompt | self.llm

            # Generate the response
            response = chain.invoke({
                "context": context,
                "chat_history": "\n".join(f"{role}: {text}" for role, text in self.chat_history[-3:]),
                "question": query
            })

            # Update chat history
            self.chat_history.append(("User", query))
            self.chat_history.append(("Assistant", response))
            return response
        except Exception as e:
            return f"Error processing query: {str(e)}"


# Initialize the chatbot once at import time (the model load is slow)
chatbot = EnhancedChatbot()

# Create the Gradio interface
demo = gr.Interface(
    fn=chatbot.chat,
    inputs=[
        gr.Textbox(
            label="Your Question",
            placeholder="Ask anything about the document...",
            lines=2
        ),
        gr.State([])  # chat history slot (unused; kept to match the fn signature)
    ],
    outputs=gr.Textbox(label="Answer", lines=10),
    title="🤖 Enhanced Document Q&A System",
    description="""
### Advanced Document Question-Answering System

**Available Query Types:**
- 📝 **General Questions**: Just ask normally
- 📊 **Summaries**: Include words like "summarize" or "overview"
- 🔧 **Technical Details**: Use words like "technical" or "explain how"
- 🔍 **Comparisons**: Ask to "compare" or use "versus"

*The system will automatically select the best response format based on your question.*
""",
    examples=[
        ["Can you summarize the main points of the document?"],
        ["What are the technical details about the implementation?"],
        ["Compare the different approaches mentioned in the text."],
        ["What are the key concepts discussed?"]
    ],
    theme=gr.themes.Soft()
)

# Launch the interface
if __name__ == "__main__":
    demo.launch()
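
# To run locally (dependencies assumed from the imports above):
#   pip install gradio langchain-core langchain-community transformers torch \
#       chromadb sentence-transformers tqdm python-dotenv huggingface-hub \
#       accelerate bitsandbytes
#   python app.py   # assuming this file is saved as app.py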