swamisharan commited on
Commit
78d81f3
·
verified ·
1 Parent(s): 5a3c78a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -49
app.py CHANGED
@@ -1,70 +1,83 @@
1
  import gradio as gr
2
  import os
3
  import torch
4
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
 
5
  from langchain.embeddings import SentenceTransformerEmbeddings
6
  from langchain.vectorstores import Chroma
7
  from langchain.llms.huggingface_pipeline import HuggingFacePipeline
8
  from langchain.chains import RetrievalQA
9
  from langchain.document_loaders import PDFMinerLoader
10
  from langchain.text_splitter import RecursiveCharacterTextSplitter
11
- from chromadb.config import Settings
 
12
 
13
- # Initialize Chroma settings once
14
- CHROMA_SETTINGS = Settings(
15
- chroma_db_impl='duckdb+parquet',
16
- persist_directory="db",
17
- anonymized_telemetry=False
18
- )
19
-
20
- # Initialize the Chroma database on app start (assuming the database will be initialized only once)
21
- def init_db_if_not_exists(pdf_path):
22
- try:
23
- # Check if the database exists and load it
24
- db = Chroma(persist_directory=CHROMA_SETTINGS.persist_directory, client_settings=CHROMA_SETTINGS)
25
- db.get_collection() # This line will raise an error if the collection doesn't exist
26
- except Exception:
27
- # If not, initialize the database
28
- loader = PDFMinerLoader(pdf_path)
29
- documents = loader.load()
30
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
31
- texts = text_splitter.split_documents(documents)
32
- embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
33
- db = Chroma.from_documents(texts, embeddings, persist_directory=CHROMA_SETTINGS.persist_directory)
34
- db.persist()
35
 
36
- # Load model and create pipeline once
37
  checkpoint = "MBZUAI/LaMini-Flan-T5-783M"
38
  tokenizer = AutoTokenizer.from_pretrained(checkpoint)
39
- base_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.float32)
40
- llm_pipeline = HuggingFacePipeline(pipeline=pipeline("text2text-generation", model=base_model, tokenizer=tokenizer))
 
 
 
 
 
 
 
 
 
 
 
41
 
42
- def process_answer(instruction):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
44
- vectordb = Chroma(persist_directory=CHROMA_SETTINGS.persist_directory, embedding_function=embeddings)
45
  retriever = vectordb.as_retriever()
46
- qa = RetrievalQA.from_chain_type(llm=llm_pipeline, chain_type="stuff", retriever=retriever)
47
- generated_text = qa(instruction)
48
- return generated_text["result"]
 
 
 
 
 
49
 
50
- def chatbot(pdf_file, user_question):
51
- if pdf_file: # Only initialize if a new PDF is uploaded
52
- init_db_if_not_exists(pdf_file.name)
53
- try:
54
- answer = process_answer(user_question)
55
- return answer
56
- except Exception as e:
57
- return f"An error occurred: {str(e)}"
58
 
59
- # Create Gradio Interface
60
  iface = gr.Interface(
61
- fn=chatbot,
62
- inputs=[gr.File(type="binary", label="Upload your PDF"), gr.Textbox(lines=1, label="Ask a Question")],
63
- outputs="text",
64
- title="PDF Chatbot",
65
- description="Upload a PDF and ask questions about its content.",
66
  )
67
 
68
-
69
- # Run the Gradio interface
70
- iface.launch()
 
1
  import gradio as gr
2
  import os
3
  import torch
4
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
+ from transformers import pipeline
6
  from langchain.embeddings import SentenceTransformerEmbeddings
7
  from langchain.vectorstores import Chroma
8
  from langchain.llms.huggingface_pipeline import HuggingFacePipeline
9
  from langchain.chains import RetrievalQA
10
  from langchain.document_loaders import PDFMinerLoader
11
  from langchain.text_splitter import RecursiveCharacterTextSplitter
12
+ from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction
13
+ import chromadb
14
 
15
+ # Define Chroma Settings
16
+ CHROMA_SETTINGS = {
17
+ "chroma_db_impl": "duckdb+parquet",
18
+ "persist_directory": "db",
19
+ "anonymized_telemetry": False
20
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
+ # Load model and tokenizer
23
  checkpoint = "MBZUAI/LaMini-Flan-T5-783M"
24
  tokenizer = AutoTokenizer.from_pretrained(checkpoint)
25
+ base_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, device_map=torch.device("cpu"), torch_dtype=torch.float32)
26
+
27
+ # Define functions
28
+ def data_ingestion(file_path):
29
+ loader = PDFMinerLoader(file_path)
30
+ documents = loader.load()
31
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=500)
32
+ texts = text_splitter.split_documents(documents)
33
+ embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
34
+ db = Chroma.from_documents(texts, embeddings, persist_directory=CHROMA_SETTINGS["persist_directory"])
35
+ db.persist()
36
+ print(texts)
37
+ return db
38
 
39
+ def llm_pipeline():
40
+ pipe = pipeline(
41
+ "text2text-generation",
42
+ model=base_model,
43
+ tokenizer=tokenizer,
44
+ max_length=256,
45
+ do_sample=True,
46
+ temperature=0.3,
47
+ top_p=0.95
48
+ )
49
+ local_llm = HuggingFacePipeline(pipeline=pipe)
50
+ return local_llm
51
+
52
+ def qa_llm():
53
+ llm = llm_pipeline()
54
  embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
55
+ vectordb = Chroma(persist_directory=CHROMA_SETTINGS["persist_directory"], embedding_function=embeddings)
56
  retriever = vectordb.as_retriever()
57
+ qa = RetrievalQA.from_chain_type(
58
+ llm=llm,
59
+ chain_type="stuff",
60
+ retriever=retriever,
61
+ return_source_documents=True
62
+ )
63
+
64
+ return qa
65
 
66
+ def process_answer(file, instruction):
67
+ # Ingest the data from the uploaded PDF
68
+ data_ingestion(file.name)
69
+ # Process the question
70
+ qa = qa_llm()
71
+ generated_text = qa(instruction)
72
+ answer = generated_text["result"]
73
+ return answer
74
 
75
+ # Define Gradio interface
76
  iface = gr.Interface(
77
+ fn=process_answer,
78
+ inputs=["file", "text"],
79
+ outputs="text"
 
 
80
  )
81
 
82
+ # Launch the interface
83
+ iface.launch()