# advisor / app.py
# Last commit: 2b58400 "Update app.py" (verified), by veerukhannan
# File size: 6.78 kB
import gradio as gr
from typing import List, Dict
from langchain_huggingface import HuggingFacePipeline # Fixed import
from langchain_core.prompts import ChatPromptTemplate
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import chromadb
from chromadb.utils import embedding_functions
import torch
import os
class LegalChatbot:
    """Retrieval-augmented Q&A assistant for the Bharatiya Nyaya Sanhita, 2023.

    The act's text is split into fixed-width chunks and stored in a ChromaDB
    collection on the first query. Each question retrieves the closest chunks
    and sends them, together with the recent conversation, to a small local
    chat model.
    """

    # Size, in characters, of each slice of the source text stored as a document.
    CHUNK_SIZE = 512
    # Number of chunks sent to ChromaDB per upsert call.
    BATCH_SIZE = 50
    # Number of chunks retrieved per question.
    N_RESULTS = 3
    # Bounds on the conversation replayed into the prompt. The model's context
    # window is small (TinyLlama-1.1B reportedly uses 2048 tokens -- confirm),
    # so an unbounded transcript would eventually overflow it.
    MAX_HISTORY_MESSAGES = 6
    MAX_HISTORY_CHARS = 1500

    def __init__(self):
        print("Initializing Legal Chatbot...")
        # ChromaDB client; documents are loaded lazily by _initialize_database.
        self.chroma_client = chromadb.Client()
        # Sentence-transformers embedding model, run on CPU.
        self.embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name="all-MiniLM-L6-v2",
            device="cpu"
        )
        # get_or_create (instead of create) so construction does not fail when
        # "text_collection" already exists for this client (e.g. a second
        # LegalChatbot in the same process, if chromadb shares in-memory state).
        self.collection = self.chroma_client.get_or_create_collection(
            name="text_collection",
            embedding_function=self.embedding_function,
            metadata={"hnsw:space": "cosine"}
        )
        # Small chat model suitable for CPU inference.
        pipe = pipeline(
            "text-generation",
            model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            max_new_tokens=512,
            # Without do_sample=True, generation is greedy and transformers
            # ignores temperature/top_p (with a warning).
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            repetition_penalty=1.15,
            # Return only the newly generated text. The default (True) echoes
            # the entire prompt -- rules, retrieved context and history -- back
            # as part of the "answer".
            return_full_text=False,
            device="cpu"
        )
        self.llm = HuggingFacePipeline(pipeline=pipe)
        # Prompt template; {context}, {chat_history} and {question} are filled
        # in by chat().
        self.template = """
IMPORTANT: You are a helpful assistant that provides information about the Bharatiya Nyaya Sanhita, 2023 based on the retrieved context.
STRICT RULES:
1. Base your response ONLY on the provided context
2. If you cannot find relevant information, respond with: "I apologize, but I cannot find information about that in the database."
3. Do not make assumptions or use external knowledge
4. Be concise and accurate in your responses
5. If quoting from the context, clearly indicate it
Context: {context}
Chat History: {chat_history}
Question: {question}
Answer:"""
        self.prompt = ChatPromptTemplate.from_template(self.template)
        # Kept for backward compatibility only. It is no longer used to build
        # prompts: one instance serves every Gradio session, so an
        # instance-level transcript mixed different users' conversations.
        self.chat_history = ""
        # Set once the documents have been loaded into the collection.
        self.initialized = False

    @staticmethod
    def _chunk_text(text: str, chunk_size: int) -> List[str]:
        """Split *text* into consecutive slices of at most *chunk_size* characters.

        Raises:
            ValueError: if *chunk_size* is not positive.
        """
        if chunk_size <= 0:
            raise ValueError("chunk_size must be positive")
        return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]

    def _initialize_database(self) -> bool:
        """Load the act's text into the vector collection (only once).

        Reads 'a2023-45.txt' (the act) and 'index.txt' relative to the current
        working directory.

        Returns:
            True if the collection is ready. False if loading failed; the
            error is printed, and the call can be retried on the next query.
        """
        if self.initialized:
            return True
        try:
            print("Loading documents into database...")
            with open('a2023-45.txt', 'r', encoding='utf-8') as f:
                text_content = f.read()
            with open('index.txt', 'r', encoding='utf-8') as f:
                index_lines = f.readlines()

            chunks = self._chunk_text(text_content, self.CHUNK_SIZE)
            for start in range(0, len(chunks), self.BATCH_SIZE):
                batch = chunks[start:start + self.BATCH_SIZE]
                numbers = range(start, start + len(batch))
                # NOTE(review): chunk j is labelled with line j of index.txt,
                # but chunks are fixed-width character slices, so these labels
                # presumably do not match the section each chunk contains.
                # Confirm index.txt's format and align labels with real
                # section boundaries.
                metadatas = [{
                    "index": index_lines[j].strip() if j < len(index_lines) else f"Chunk {j+1}",
                    "chunk_number": j
                } for j in numbers]
                # upsert (instead of add) tolerates ids already present, e.g.
                # when the collection was populated by an earlier instance.
                self.collection.upsert(
                    documents=batch,
                    ids=[f"doc_{j}" for j in numbers],
                    metadatas=metadatas
                )
            self.initialized = True
            return True
        except Exception as e:
            print(f"Error initializing database: {str(e)}")
            return False

    def _search_database(self, query: str) -> List[Dict]:
        """Return the chunks closest to *query*.

        Each hit is a dict with "content" (chunk text), "metadata" (the stored
        metadata dict) and "score" (1 - cosine distance, so higher means more
        similar). Returns [] if the query fails; the error is printed.
        """
        try:
            results = self.collection.query(
                query_texts=[query],
                n_results=self.N_RESULTS,
                include=["documents", "metadatas", "distances"]
            )
            return [
                {"content": doc, "metadata": meta, "score": 1 - dist}
                for doc, meta, dist in zip(
                    results['documents'][0],
                    results['metadatas'][0],
                    results['distances'][0]
                )
            ]
        except Exception as e:
            print(f"Error searching database: {str(e)}")
            return []

    def _format_history(self, history) -> str:
        """Render the conversation passed by Gradio as "User:/AI:" prompt lines.

        Accepts both Gradio history formats: a list of [user, bot] pairs, or a
        list of {"role": ..., "content": ...} dicts. Entries whose text is not
        a string (None placeholders, attachments) and roles other than
        user/assistant are skipped. Only the last MAX_HISTORY_MESSAGES
        messages are kept, and the transcript is cut to its last
        MAX_HISTORY_CHARS characters so the prompt stays small.

        Args:
            history: The conversation so far, or None.

        Returns:
            The formatted transcript, or "" when there is nothing to include.
        """
        if not history or self.MAX_HISTORY_MESSAGES <= 0 or self.MAX_HISTORY_CHARS <= 0:
            return ""
        speakers = {"user": "User", "assistant": "AI"}
        messages = []
        for item in history:
            if isinstance(item, dict):
                speaker = speakers.get(item.get("role"))
                content = item.get("content")
                if speaker and isinstance(content, str):
                    messages.append(f"{speaker}: {content}")
            elif isinstance(item, (list, tuple)) and len(item) == 2:
                user_msg, bot_msg = item
                if isinstance(user_msg, str):
                    messages.append(f"User: {user_msg}")
                if isinstance(bot_msg, str):
                    messages.append(f"AI: {bot_msg}")
        transcript = "\n".join(messages[-self.MAX_HISTORY_MESSAGES:])
        # Keep the tail: the most recent exchange matters most for follow-ups.
        return transcript[-self.MAX_HISTORY_CHARS:]

    def chat(self, query: str, history=None) -> str:
        """Answer *query* using retrieved passages and the session's history.

        Args:
            query: The user's question.
            history: The conversation so far, as supplied by Gradio's
                ChatInterface (see _format_history), or None.

        Returns:
            The model's answer, a fixed apology when nothing is retrieved, or
            an error message string (exceptions are not propagated).
        """
        try:
            if not self.initialized and not self._initialize_database():
                return "Error: Unable to initialize the database. Please try again."

            search_results = self._search_database(query)
            if not search_results:
                return "I apologize, but I cannot find information about that in the database."

            context = "\n\n".join(
                f"[Section {r['metadata']['index']}]\n{r['content']}"
                for r in search_results
            )

            chain = self.prompt | self.llm
            result = chain.invoke({
                "context": context,
                # Per-session history from the caller, never shared instance
                # state.
                "chat_history": self._format_history(history),
                "question": query
            })
            return result.strip() if isinstance(result, str) else str(result)
        except Exception as e:
            return f"Error processing query: {str(e)}"
# One chatbot instance, created at import time and shared by every session
# of the web UI.
chatbot = LegalChatbot()

# Chat UI: Gradio calls chatbot.chat(message, history) for each user turn.
iface = gr.ChatInterface(
    fn=chatbot.chat,
    title="Bharatiya Nyaya Sanhita, 2023 - Legal Assistant",
    description=(
        "Ask questions about the Bharatiya Nyaya Sanhita, 2023. "
        "The system will initialize on your first query."
    ),
    # Clickable sample questions shown under the chat box.
    examples=[
        "What is criminal conspiracy?",
        "What are the punishments for corruption?",
        "Explain the concept of culpable homicide",
        "What constitutes theft under the act?",
    ],
    theme=gr.themes.Soft(),
)
# Start the web server only when this file is executed directly, not on import.
if __name__ == "__main__":
    # No public share link; show Python errors in the browser UI.
    launch_options = {"share": False, "show_error": True}
    iface.launch(**launch_options)