import os
import glob

import gradio as gr
import numpy as np
import uvicorn
from datetime import datetime
from fastapi import FastAPI, Query, Request
from pydantic import BaseModel
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from langchain_core.prompts import PromptTemplate
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_ollama import OllamaLLM
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline

# Local helper modules
from load_document import load_data
from split_document import split_docs
from embed_docs import embed_docs
from retrieve import retrieve
# from js import js
# from theme import theme

app = FastAPI(title="Know The Law", description="A FastAPI application for legal assistance using AI.")

vector_store_path = "/home/user/VectorStoreDB"
index_name = "faiss_index"
full_index_path = os.path.join(vector_store_path, index_name)

# Create the embedder with a specific model
embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

# Speech pipeline for voice input (currently disabled)
# transcriber = pipeline("automatic-speech-recognition", model="openai/whisper-base.en", device="cpu")

def fetch_doc():
    # Adjust the path as needed, e.g., './' for the current directory
    pdf_files = glob.glob("Document/*.pdf")
    # To include subdirectories as well:
    # pdf_files = glob.glob("**/*.pdf", recursive=True)
    return pdf_files

# Define the LLM
hf_token = os.environ.get("HF_TOKEN", "").strip()  # Set your Hugging Face token in the HF_TOKEN environment variable
# Earlier alternatives, kept for reference:
# llm = pipeline("text-generation", model="mistralai/Mistral-7B-Instruct-v0.3", device="cpu", token=hf_token)
# llm = OllamaLLM(model="mistral:7b-instruct", base_url="http://host.docker.internal:11434")

model_id = "google/gemma-2b-it"

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cpu", torch_dtype="auto", token=hf_token)

# Create the text-generation pipeline
hf_pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,
    temperature=0.7,
    top_p=0.9,
    do_sample=True,
)
llm = HuggingFacePipeline(pipeline=hf_pipe)
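
# Optional sanity check (a sketch; the sample prompt is illustrative only).
# Commented out to avoid an extra generation pass at startup.
# print(hf_pipe("Summarize the rule of law in one sentence.")[0]["generated_text"])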

pdf_files = fetch_doc()  # Fetch the dataset
chunks = None
loaded_docs = []

# Build the index only if it does not already exist; otherwise just query it
if not os.path.exists(full_index_path):
    for doc in pdf_files:
        print(f"Loading.....{doc}")
        docs = load_data(doc)  # Load the document
        loaded_docs.append(docs)
    final_docs = [item for sublist in loaded_docs for item in sublist]  # Flatten the list
    chunks = split_docs(final_docs, embedder=embedder)  # Split the documents
saved_vector = embed_docs(chunks, embedder=embedder)  # Embed the chunks (or load the saved index)
retrieved = retrieve(saved_vector)  # Build a retriever over similar docs

# Define the prompt template ({input} matches the key passed to qa_chain.invoke below)
prompt = """
You are The Law Assistant, an AI trained to help Nigerians understand their legal rights and obligations. Using the provided context below, answer user questions related to Nigerian law.

Instructions:
1. Base your responses strictly on the given context or verified legal sources.
2. If the answer is not in the context and you're unsure, respond with: "I don't know based on the available information." Do not fabricate or speculate.
3. Keep your answers clear, concise, and jargon-free.
4. Always cite the legal source(s) or reference(s) you used (e.g., constitution section, legal act, court ruling).

Context: {context}

Question: {input}

Helpful Answer:"""
QA_CHAIN_PROMPT = PromptTemplate.from_template(template=prompt)

# Prompt used to render each retrieved document
document_prompt = PromptTemplate(
    input_variables=["page_content", "source"],
    template="Context:\ncontent:{page_content}\nsource:{source}",
)

# Create the stuff-documents chain
combine_docs_chain = create_stuff_documents_chain(
    llm,
    QA_CHAIN_PROMPT,
    document_prompt=document_prompt,
)

# Create the retrieval chain
qa_chain = create_retrieval_chain(
    retriever=retrieved,
    combine_docs_chain=combine_docs_chain,
)
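
# Minimal direct (non-HTTP) smoke test, as a sketch; the question is illustrative only.
# example = qa_chain.invoke({"input": "What rights does Chapter IV of the 1999 Constitution protect?"})
# print(example["answer"])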

class QueryRequest(BaseModel):
    question: str


@app.get("/")
def home():
    return {"message": "Welcome to the Know The Law API. Use POST /query to ask legal questions."}


@app.post("/query")
def respond(query: QueryRequest):
    # Invoke the chain with the question
    question = query.question
    result = qa_chain.invoke({"input": question})
    # Return the answer
    return {"answer": result["answer"]}