# PDFgpt / app.py
import os
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import pipeline
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.vectorstores import Chroma
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.chains import RetrievalQA
from langchain.document_loaders import PDFMinerLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
import gradio as gr
from gradio.components import File
# Define Chroma Settings
CHROMA_SETTINGS = {
    "chroma_db_impl": "duckdb+parquet",
    "persist_directory": "db",
    "anonymized_telemetry": False
}
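# Note: only persist_directory is read by the code below; the other keys are
# legacy chromadb client settings and are not passed to Chroma here.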
# Load model and tokenizer
checkpoint = "MBZUAI/LaMini-Flan-T5-783M"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
base_model = AutoModelForSeq2SeqLM.from_pretrained(
    checkpoint, device_map="cpu", torch_dtype=torch.float32
)
# Define functions
def data_ingestion(file_path):
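    """Load a PDF, split it into chunks, embed them, and persist a Chroma index."""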
    loader = PDFMinerLoader(file_path)
    documents = loader.load()
    # a small overlap (well below chunk_size) avoids near-duplicate chunks
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    texts = text_splitter.split_documents(documents)
    embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    db = Chroma.from_documents(texts, embeddings, persist_directory=CHROMA_SETTINGS["persist_directory"])
    db.persist()
    return db
def llm_pipeline():
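    """Wrap the local LaMini-Flan-T5 model in a LangChain-compatible LLM."""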
    pipe = pipeline(
        "text2text-generation",
        model=base_model,
        tokenizer=tokenizer,
        max_length=256,
        do_sample=True,
        temperature=0.3,
        top_p=0.95
    )
    local_llm = HuggingFacePipeline(pipeline=pipe)
    return local_llm
def qa_llm():
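    """Build a RetrievalQA chain over the persisted Chroma index."""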
    llm = llm_pipeline()
    embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    vectordb = Chroma(persist_directory=CHROMA_SETTINGS["persist_directory"], embedding_function=embeddings)
    retriever = vectordb.as_retriever()
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True
    )
    return qa
def process_answer(file_path, question):
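    """Gradio callback: ingest the uploaded PDF and answer the given question."""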
    # data_ingestion persists the index to disk; qa_llm() re-opens it from the same directory
    data_ingestion(file_path)
    qa = qa_llm()
    generated_text = qa(question)
    answer = generated_text["result"]
    return answer
# Create a Gradio interface
demo = gr.Interface(
    fn=process_answer,
    inputs=[
        # type="filepath" passes the upload as a path string (Gradio 4.x; older versions used type="file")
        File(type="filepath", file_types=[".pdf"], label="PDF document"),
        gr.Textbox(label="Question", placeholder="Ask something about the uploaded PDF"),
    ],
    outputs="text",
    title="Chatbot",
    description="Upload a PDF and ask a question about its contents."
)
# Launch the Gradio interface
demo.launch()
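# demo.launch(share=True) would additionally expose a temporary public URL when running locally.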