Spaces:
Running
Running
from langchain_core.prompts import PromptTemplate | |
from langchain.chains import create_retrieval_chain | |
from langchain.chains.combine_documents import create_stuff_documents_chain | |
# import gradio as gr | |
import numpy as np | |
from langchain_ollama import OllamaLLM | |
from langchain_huggingface import HuggingFaceEmbeddings | |
# from langchain_community.llms import HuggingFacePipeline | |
from load_document import load_data | |
from split_document import split_docs | |
from embed_docs import embed_docs | |
from retrieve import retrieve | |
from datetime import datetime | |
from js import js | |
from theme import theme | |
import os | |
import glob | |
from fastapi import FastAPI, Query, Request | |
from pydantic import BaseModel | |
import uvicorn | |
# FastAPI application instance for the legal-assistant API.
app = FastAPI(
    title="Know The Law",
    description="A FastAPI application for legal assistance using AI.",
)

# On-disk location of the persisted vector index; whether this path exists
# decides if the documents are (re)processed at startup.
vector_store_path = "/home/user/VectorStoreDB"
index_name = "faiss_index"
full_index_path = os.path.join(vector_store_path, index_name)

# Embedding model shared by the document splitting and embedding steps.
embedder = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
)
def fetch_doc(pattern="*.pdf"):
    """Return the files to ingest, in a deterministic order.

    Args:
        pattern: Glob pattern resolved relative to the current working
            directory. Defaults to ``"*.pdf"``, which matches the original
            behaviour.

    Returns:
        A sorted list of matching paths. ``glob.glob`` yields matches in
        arbitrary directory-listing order, so sorting guarantees the
        documents are loaded in the same order on every run.
    """
    return sorted(glob.glob(pattern))
# LLM served by a local Ollama instance; used to generate the answers.
llm = OllamaLLM(model="llama3.2", base_url="http://localhost:11434")

pdf_files = fetch_doc()  # source documents in the working directory
chunks = None
loaded_docs = []

# Load and split the PDFs only when no persisted index exists yet. When the
# index is already on disk, ``chunks`` stays None and the existing index is
# queried instead.
# NOTE(review): assumes embed_docs() loads the saved index from disk when it
# receives chunks=None -- confirm against embed_docs.py.
if not os.path.exists(full_index_path):
    if not pdf_files:
        # Fail with a clear message instead of trying to embed an empty
        # document set.
        raise RuntimeError(
            f"No vector index found at {full_index_path!r} and no PDF files "
            "in the working directory to build one from."
        )
    for doc in pdf_files:
        print(f"Loading.....{doc}")
        docs = load_data(doc)  # presumably a list of documents per file
        loaded_docs.append(docs)
    # Flatten the per-file lists into a single list of documents.
    final_docs = [item for sublist in loaded_docs for item in sublist]
    chunks = split_docs(final_docs, embedder=embedder)

saved_vector = embed_docs(chunks, embedder=embedder)
retrieved = retrieve(saved_vector)  # retriever over similar chunks
# Answer prompt for the stuff-documents chain. The placeholders must match the
# keys the retrieval chain supplies: ``{context}`` receives the formatted
# retrieved documents, and ``{input}`` receives the user's question (the key
# passed to ``qa_chain.invoke``). Do not write ``{{question}}``: doubled braces
# are an escaped literal, so the question would never be substituted.
prompt = """
You are The Law Assistant, an AI trained to help Nigerians understand their legal rights and obligations. Using the provided context below, answer user questions related to Nigerian law.
Instructions:
1. Base your responses strictly on the given context or verified legal sources.
2. If the answer is not in the context and you're unsure, respond with: "I don't know based on the available information." Do not fabricate or speculate.
3. Keep your answers clear, concise, and jargon-free.
4. Always cite the legal source(s) or reference(s) you used (e.g., constitution section, legal act, court ruling).
Context: {context}
Question: {input}
Helpful Answer:"""
# Main QA prompt, built from the template string defined above it.
QA_CHAIN_PROMPT = PromptTemplate.from_template(prompt)

# How each retrieved document is rendered inside {context}: its text plus a
# "source" field (presumably read from each document's metadata).
document_prompt = PromptTemplate(
    template="Context:\ncontent:{page_content}\nsource:{source}",
    input_variables=["page_content", "source"],
)
# Chain that renders every retrieved document with document_prompt, stuffs
# them into QA_CHAIN_PROMPT and sends the result to the LLM.
combine_docs_chain = create_stuff_documents_chain(
    llm=llm,
    prompt=QA_CHAIN_PROMPT,
    document_prompt=document_prompt,
)

# End-to-end RAG chain. respond() invokes it with {"input": question} and
# reads the generated text from the result's "answer" key.
qa_chain = create_retrieval_chain(
    retriever=retrieved,
    combine_docs_chain=combine_docs_chain,
)
class QueryRequest(BaseModel):
    """JSON request body accepted by the question-answering endpoint."""

    question: str  # the user's legal question, forwarded to the QA chain
@app.get("/")
def home():
    """Serve the welcome message and smoke-test the LLM connection.

    Sends a fixed prompt to the Ollama model and prints the reply, or the
    failure, to stdout. The model result never affects the HTTP response.

    Returns:
        A dict with a ``message`` key directing clients to ``POST /query``.
    """
    # NOTE(review): this makes a full LLM call on every GET /, which is slow;
    # consider moving the smoke test to an application startup hook.
    print("Testing Ollama LLM response...")
    try:
        response = llm.invoke("What is the capital of Nigeria?")
        print("Response:", response)
    except Exception as e:
        # Best-effort diagnostic: report the failure but still serve the page.
        print("Failed to invoke model:", e)
    return {"message": "Welcome to the Know The Law API. Use POST /query to ask legal questions."}
@app.post("/query")
def respond(query: QueryRequest):
    """Answer a legal question using the retrieval-augmented QA chain.

    Args:
        query: Request body carrying the user's ``question``.

    Returns:
        ``{"answer": <text>}`` on success. On any failure, a JSON response
        with HTTP status 500 and body ``{"error": <message>}``, so clients
        and monitoring can detect errors from the status code.
    """
    question = query.question
    try:
        result = qa_chain.invoke({"input": question})
        return {"answer": result["answer"]}
    except Exception as e:
        import traceback

        from fastapi.responses import JSONResponse

        # Log the full traceback server-side for debugging.
        traceback.print_exc()
        return JSONResponse(status_code=500, content={"error": str(e)})