Alimubariz124 committed on
Commit 63f5111 · verified · 1 Parent(s): 7423963

Update app.py

Files changed (1)
  1. app.py +98 -100
app.py CHANGED
@@ -1,14 +1,17 @@
  import os
- import pickle
  import PyPDF2
  import numpy as np
  import faiss
- import torch
- import streamlit as st
- from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForLanguageModeling, Trainer, TrainingArguments
- from datasets import Dataset
  from sentence_transformers import SentenceTransformer
- from peft import LoraConfig, get_peft_model

  # Load embedding model
  @st.cache_resource
@@ -24,47 +27,15 @@ def parse_pdf(file):
      return text

  # Split text into chunks
- def split_text(text, chunk_size=500):
-     return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]

  # Create FAISS index
- def create_faiss_index(embeddings):
-     dimension = embeddings.shape[1]
-     index = faiss.IndexFlatL2(dimension)
-     index.add(embeddings)
-     return index
-
- # Fine-tune the model
- def fine_tune_model(dataset, model_name, output_dir="./fine-tuned-model"):
-     tokenizer = AutoTokenizer.from_pretrained(model_name)
-     model = AutoModelForCausalLM.from_pretrained(model_name)
-
-     def preprocess_function(examples):
-         inputs = [f"Question: {q} Answer: {a}" for q, a in zip(examples["question"], examples["answer"])]
-         return tokenizer(inputs, truncation=True, padding="max_length", max_length=512)
-
-     tokenized_dataset = dataset.map(preprocess_function, batched=True)
-
-     training_args = TrainingArguments(
-         output_dir=output_dir,
-         per_device_train_batch_size=4,
-         num_train_epochs=3,
-         save_steps=10_000,
-         save_total_limit=2,
-     )
-
-     trainer = Trainer(
-         model=model,
-         args=training_args,
-         train_dataset=tokenized_dataset,
-         tokenizer=tokenizer,
-         data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
-     )
-
-     trainer.train()
-     model.save_pretrained(output_dir)
-     tokenizer.save_pretrained(output_dir)
-     return output_dir

  # Generate response from the model
  def generate_response(prompt, model, tokenizer):
@@ -75,62 +46,89 @@ def generate_response(prompt, model, tokenizer):

  # Main Streamlit app
  def main():
-     st.title("Chat with PDF using Fine-Tuned Llama Model")
-
-     # Step 1: Upload PDF file
-     uploaded_file = st.file_uploader("Upload a PDF file", type=["pdf"])
-     if uploaded_file is not None:
-         st.write("File uploaded successfully!")
-
-         # Process PDF
-         with st.spinner("Processing PDF..."):
-             pdf_text = parse_pdf(uploaded_file)
-             chunks = split_text(pdf_text)
              embedding_model = load_embedding_model()
-             chunk_embeddings = embedding_model.encode(chunks)
-             faiss_index = create_faiss_index(np.array(chunk_embeddings))
-
-         st.success("PDF processed! Proceed to fine-tuning.")
-
-         # Step 2: Fine-tune the model
-         if st.button("Fine-Tune Model"):
-             with st.spinner("Fine-tuning the model..."):
-                 # Create a dataset of question-answer pairs
-                 qa_pairs = []
-                 for chunk in chunks:
-                     qa_pairs.append({"question": "What is this about?", "answer": chunk[:100]})  # Simplified example
-
-                 dataset = Dataset.from_dict({
-                     "question": [pair["question"] for pair in qa_pairs],
-                     "answer": [pair["answer"] for pair in qa_pairs],
-                 })
-
-                 # Fine-tune the model
-                 model_name = "meta-llama/Llama-2-7b-chat-hf"  # Replace with your local path
-                 fine_tuned_model_path = fine_tune_model(dataset, model_name)
-
-             st.success(f"Model fine-tuned! Saved at: {fine_tuned_model_path}")
-
-             # Load the fine-tuned model
-             tokenizer = AutoTokenizer.from_pretrained(fine_tuned_model_path)
-             model = AutoModelForCausalLM.from_pretrained(fine_tuned_model_path, device_map="auto", torch_dtype=torch.float16)
-
-             st.success("Fine-tuned model loaded! You can now ask questions.")
-
-             # Step 3: Chat interface
-             user_input = st.text_input("Ask a question about the PDF:")
-             if user_input:
-                 with st.spinner("Generating response..."):
-                     # Retrieve relevant chunk
-                     query_embedding = embedding_model.encode([user_input])
-                     _, indices = faiss_index.search(query_embedding, k=1)
-                     relevant_chunk = chunks[indices[0][0]]
-
-                     # Generate response
-                     prompt = f"Context: {relevant_chunk}\nQuestion: {user_input}\nAnswer:"
-                     response = generate_response(prompt, model, tokenizer)
-
-                     st.write(f"**Response:** {response}")

  if __name__ == "__main__":
      main()
 
  import os
+ import streamlit as st
  import PyPDF2
  import numpy as np
  import faiss
+ import torch  # needed below for torch.float16 when loading the LLM
  from sentence_transformers import SentenceTransformer
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from langchain.chains import ConversationalRetrievalChain
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain.vectorstores import FAISS
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.llms import HuggingFacePipeline
+ from langchain.prompts import PromptTemplate
+ from transformers import pipeline

  # Load embedding model
  @st.cache_resource
 
      return text

  # Split text into chunks
+ def split_text(text, chunk_size=500, chunk_overlap=100):
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+     return text_splitter.split_text(text)

  # Create FAISS index
+ def create_faiss_index(texts, embedding_model):
+     # Note: builds its own HuggingFaceEmbeddings; the embedding_model argument is currently unused.
+     embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
+     vectorstore = FAISS.from_texts(texts, embeddings)
+     return vectorstore

  # Generate response from the model
  def generate_response(prompt, model, tokenizer):
 
  # Main Streamlit app
  def main():
+     st.title("Advanced Chat with Your Document")
+
+     # Initialize session state for conversation history and documents
+     if "conversation_history" not in st.session_state:
+         st.session_state.conversation_history = []
+     if "vectorstore" not in st.session_state:
+         st.session_state.vectorstore = None
+
+     # Step 1: Upload multiple PDF files
+     uploaded_files = st.file_uploader("Upload PDF files", type=["pdf"], accept_multiple_files=True)
+     if uploaded_files:
+         st.write(f"{len(uploaded_files)} file(s) uploaded successfully!")
+
+         # Process PDFs
+         with st.spinner("Processing PDFs..."):
+             all_texts = []
+             for uploaded_file in uploaded_files:
+                 pdf_text = parse_pdf(uploaded_file)
+                 chunks = split_text(pdf_text)
+                 all_texts.extend(chunks)
+
+             # Create a unified vector database
              embedding_model = load_embedding_model()
+             st.session_state.vectorstore = create_faiss_index(all_texts, embedding_model)
+
+         st.success("PDFs processed! You can now ask questions.")
+
+     # Step 2: Chat interface
+     user_input = st.text_input("Ask a question about the document(s):")
+     if user_input:
+         if st.session_state.vectorstore is None:
+             st.error("Please upload and process documents first.")
+             return
+
+         with st.spinner("Generating response..."):
+             # Load the LLM
+             model_name = "meta-llama/Llama-2-7b-chat-hf"  # Replace with your local path
+             tokenizer = AutoTokenizer.from_pretrained(model_name)
+             model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
+
+             # Set up LangChain components
+             retriever = st.session_state.vectorstore.as_retriever()
+             llm = HuggingFacePipeline(pipeline=pipeline("text-generation", model=model, tokenizer=tokenizer))
+
+             # Define a custom prompt template for Chain-of-Thought reasoning
+             prompt_template = """
+             Answer the following question based ONLY on the provided context.
+             If the question requires multi-step reasoning, break it down step by step.
+             Context: {context}
+             Question: {question}
+             Answer:
+             """
+             prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
+
+             # Create a conversational retrieval chain
+             qa_chain = ConversationalRetrievalChain.from_llm(
+                 llm=llm,
+                 retriever=retriever,
+                 combine_docs_chain_kwargs={"prompt": prompt},
+                 return_source_documents=True
+             )
+
+             # Add conversation history
+             chat_history = st.session_state.conversation_history[-3:]  # Last 3 interactions
+             result = qa_chain({"question": user_input, "chat_history": chat_history})
+
+             # Extract response and update conversation history
+             response = result["answer"]
+             st.session_state.conversation_history.append(f"User: {user_input}")
+             st.session_state.conversation_history.append(f"Bot: {response}")
+
+             st.write(f"**Response:** {response}")
+
+             # Display source documents (optional)
+             if "source_documents" in result:
+                 st.subheader("Source Documents")
+                 for doc in result["source_documents"]:
+                     st.write(doc.page_content)
+
+     # Display conversation history
+     st.subheader("Conversation History")
+     for line in st.session_state.conversation_history:
+         st.write(line)

  if __name__ == "__main__":
      main()
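
Note on the new main(): the Llama checkpoint, the text-generation pipeline, and the retrieval chain are all constructed inside the `if user_input:` branch, so they are rebuilt on every question, and the conversation history is stored as plain "User:/Bot:" strings even though ConversationalRetrievalChain normally takes its chat_history as (user, assistant) pairs. Below is a minimal sketch of how this could be restructured with Streamlit's resource cache; it is an illustration under those assumptions, not part of the commit, and the helper names load_llm and build_qa_chain are invented for the example.

import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain.llms import HuggingFacePipeline
from langchain.chains import ConversationalRetrievalChain

@st.cache_resource
def load_llm(model_name="meta-llama/Llama-2-7b-chat-hf"):
    # Load the checkpoint once per session instead of on every question.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
    return HuggingFacePipeline(pipeline=pipeline("text-generation", model=model, tokenizer=tokenizer))

def build_qa_chain(vectorstore, prompt):
    # Only the retriever depends on the uploaded documents; the LLM comes from the cache.
    return ConversationalRetrievalChain.from_llm(
        llm=load_llm(),
        retriever=vectorstore.as_retriever(),
        combine_docs_chain_kwargs={"prompt": prompt},
        return_source_documents=True,
    )

# Hypothetical usage inside main(), once the vectorstore exists:
# qa_chain = build_qa_chain(st.session_state.vectorstore, prompt)
# history = st.session_state.conversation_history[-3:]  # list of (user, bot) tuples
# result = qa_chain({"question": user_input, "chat_history": history})
# st.session_state.conversation_history.append((user_input, result["answer"]))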