# DeLaw_ollama / app.py
from langchain_core.prompts import PromptTemplate
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from load_document import load_data
from split_document import split_docs
from embed_docs import embed_docs
from retrieve import retrieve
import os
import glob
from fastapi import FastAPI
from pydantic import BaseModel
import uvicorn

app = FastAPI(title="Know The Law", description="A FastAPI application for legal assistance using AI.")
vector_store_path = "/home/user/VectorStoreDB"
index_name = "faiss_index"
full_index_path = os.path.join(vector_store_path, index_name)
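
# Defensive addition (not in the original): make sure the vector-store directory
# exists before anything tries to save or load an index there. embed_docs or
# FAISS's save_local may already handle this, so exist_ok=True keeps it harmless.
os.makedirs(vector_store_path, exist_ok=True)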
# Create the embedder with a specific sentence-transformers model
embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

# Speech-recognition pipeline, kept for reference (currently unused):
# transcriber = pipeline("automatic-speech-recognition", model="openai/whisper-base.en", device="cpu")

def fetch_doc():
    # Adjust the path as needed, e.g., "./" for the current directory.
    pdf_files = glob.glob("Document/*.pdf")
    # To include subdirectories instead:
    # pdf_files = glob.glob("**/*.pdf", recursive=True)
    return pdf_files

# Define the LLM. Set your Hugging Face token in the HF_TOKEN environment variable.
hf_token = os.environ.get("HF_TOKEN", "").strip()  # default to "" so a missing token cannot crash .strip()
# Alternative backends, kept for reference:
# llm = pipeline("text-generation", model="mistralai/Mistral-7B-Instruct-v0.3", device="cpu", token=hf_token)
# llm = OllamaLLM(model="mistral:7b-instruct", base_url="http://host.docker.internal:11434")  # requires langchain_ollama
model_id = "google/gemma-2b-it"

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cpu", torch_dtype="auto", token=hf_token)

# Create the text-generation pipeline
hf_pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,
    temperature=0.7,
    top_p=0.9,
    do_sample=True,
    return_full_text=False,  # return only the completion; otherwise the echoed prompt ends up in the answer
)
llm = HuggingFacePipeline(pipeline=hf_pipe)

pdf_files = fetch_doc()  # Fetch the PDF dataset
chunks = None
loaded_docs = []

# Only ingest the PDFs when no FAISS index has been saved yet; otherwise
# embed_docs is expected to load the existing index from full_index_path.
if not os.path.exists(full_index_path):
    for doc in pdf_files:
        print(f"Loading.....{doc}")
        docs = load_data(doc)  # Load one PDF into a list of documents
        loaded_docs.append(docs)
    final_docs = [item for sublist in loaded_docs for item in sublist]  # Flatten the list of lists
    chunks = split_docs(final_docs, embedder=embedder)  # Split documents into chunks

saved_vector = embed_docs(chunks, embedder=embedder)  # Embed the chunks, or load the saved index
retrieved = retrieve(saved_vector)  # Build a retriever over similar docs

# Define the prompt template
prompt = """
You are The Law Assistant, an AI trained to help Nigerians understand their legal rights and obligations. Using the provided context below, answer user questions related to Nigerian law.
Instructions:
1. Base your responses strictly on the given context or verified legal sources.
2. If the answer is not in the context and you're unsure, respond with: "I don't know based on the available information." Do not fabricate or speculate.
3. Keep your answers clear, concise, and jargon-free.
4. Always cite the legal source(s) or reference(s) you used (e.g., constitution section, legal act, court ruling).
Context: {context}
Question: {{question}}
Helpful Answer:"""
QA_CHAIN_PROMPT = PromptTemplate.from_template(template=prompt)
# Create document prompt
document_prompt = PromptTemplate(
    input_variables=["page_content", "source"],
    template="Context:\ncontent:{page_content}\nsource:{source}",
)

# Create the stuff documents chain
combine_docs_chain = create_stuff_documents_chain(
    llm,
    QA_CHAIN_PROMPT,
    document_prompt=document_prompt,
)

# Create the retrieval chain
qa_chain = create_retrieval_chain(
    retriever=retrieved,
    combine_docs_chain=combine_docs_chain,
)

class QueryRequest(BaseModel):
    question: str

@app.get("/")
def home():
return {"message": "Welcome to the Know The Law API. Use POST /query to ask legal questions."}
@app.post("/query")
def respond(query: QueryRequest):
# Invoke the chain with the question
question = query.question
result = qa_chain.invoke({"input":question})
# Return the answer
return {"answer": result['answer']}