from langchain_core.prompts import PromptTemplate
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from langchain_ollama import OllamaLLM
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from load_document import load_data
from split_document import split_docs
from embed_docs import embed_docs
from retrieve import retrieve
# from js import js
# from theme import theme
import os
import glob
from fastapi import FastAPI
from pydantic import BaseModel
import uvicorn

app = FastAPI(title="Know The Law", description="A FastAPI application for legal assistance using AI.")

vector_store_path = "/home/user/VectorStoreDB"
index_name = "faiss_index"
full_index_path = os.path.join(vector_store_path, index_name)

# Create the embedder with a specific model
embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

# Initialize our speech pipeline
# transcriber = pipeline("automatic-speech-recognition", model="openai/whisper-base.en", device="cpu")


def fetch_doc():
    """Collect the PDF documents to index."""
    # Adjust the path as needed, e.g. "./" for the current directory.
    pdf_files = glob.glob("Document/*.pdf")
    # To include subdirectories instead:
    # pdf_files = glob.glob("**/*.pdf", recursive=True)
    return pdf_files


# Define the LLM. Set your Hugging Face token in the HF_TOKEN environment variable.
hf_token = os.environ.get("HF_TOKEN", "").strip()  # default to "" so a missing token doesn't crash on .strip()
# Alternative backends, kept for reference:
# llm = pipeline("text-generation", model="mistralai/Mistral-7B-Instruct-v0.3", device="cpu", token=hf_token)
# llm = OllamaLLM(model="mistral:7b-instruct", base_url="http://host.docker.internal:11434")
model_id = "google/gemma-2b-it"

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cpu", torch_dtype="auto", token=hf_token)

# Create the text-generation pipeline
hf_pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,
    temperature=0.7,
    top_p=0.9,
    do_sample=True,
)
llm = HuggingFacePipeline(pipeline=hf_pipe)

pdf_files = fetch_doc()  # Fetch the dataset
chunks = None
loaded_docs = []

# Only load and split the PDFs when no index exists yet; otherwise just query the saved index.
if not os.path.exists(full_index_path):
    for doc in pdf_files:
        print(f"Loading.....{doc}")
        docs = load_data(doc)  # Load document
        loaded_docs.append(docs)
    final_docs = [item for sublist in loaded_docs for item in sublist]  # Flatten the list of lists
    chunks = split_docs(final_docs, embedder=embedder)  # Split documents into chunks

saved_vector = embed_docs(chunks, embedder=embedder)  # Embed documents (or reuse the saved index)
retrieved = retrieve(saved_vector)  # Build a retriever for similar documents
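# NOTE: the os.path.exists(full_index_path) guard above only pays off if
# embed_docs() persists the index and reloads it when chunks is None.
# A minimal sketch of that behaviour, assuming a FAISS-backed store (this is
# an illustrative assumption; the real embed_docs in embed_docs.py may differ):
#
#     from langchain_community.vectorstores import FAISS
#
#     def embed_docs(chunks, embedder):
#         if chunks is None:
#             # Reuse the index saved by a previous run
#             return FAISS.load_local(full_index_path, embedder,
#                                     allow_dangerous_deserialization=True)
#         store = FAISS.from_documents(chunks, embedder)
#         store.save_local(full_index_path)
#         return store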
# Define the prompt template
prompt = """You are The Law Assistant, an AI trained to help Nigerians understand their legal rights and obligations.
Using the provided context below, answer user questions related to Nigerian law.

Instructions:
1. Base your responses strictly on the given context or verified legal sources.
2. If the answer is not in the context and you're unsure, respond with: "I don't know based on the available information." Do not fabricate or speculate.
3. Keep your answers clear, concise, and jargon-free.
4. Always cite the legal source(s) or reference(s) you used (e.g., constitution section, legal act, court ruling).

Context: {context}

Question: {input}

Helpful Answer:"""

QA_CHAIN_PROMPT = PromptTemplate.from_template(template=prompt)

# Create the per-document prompt
document_prompt = PromptTemplate(
    input_variables=["page_content", "source"],
    template="Context:\ncontent:{page_content}\nsource:{source}",
)

# Create the stuff-documents chain
combine_docs_chain = create_stuff_documents_chain(
    llm, QA_CHAIN_PROMPT, document_prompt=document_prompt
)

# Create the retrieval chain
qa_chain = create_retrieval_chain(
    retriever=retrieved,
    combine_docs_chain=combine_docs_chain,
)


class QueryRequest(BaseModel):
    question: str


@app.get("/")
def home():
    return {"message": "Welcome to the Know The Law API. Use POST /query to ask legal questions."}


@app.post("/query")
def respond(query: QueryRequest):
    # create_retrieval_chain expects the user question under the "input" key,
    # which is why the prompt above uses {input} rather than {question}.
    result = qa_chain.invoke({"input": query.question})
    # Return only the generated answer
    return {"answer": result["answer"]}
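# Allow running this module directly, e.g. `python app.py` (filename assumed;
# host and port below are illustrative defaults, adjust for your deployment):
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example request (assumes the server is listening on localhost:8000):
#   curl -X POST http://localhost:8000/query \
#        -H "Content-Type: application/json" \
#        -d '{"question": "What does the Nigerian constitution say about freedom of expression?"}'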