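"""Chat with Notes and AI: a Streamlit app that answers questions about subject notes.

Text files are split into chunks, embedded with a sentence-transformers model,
indexed in FAISS, and the retrieved context is passed to a local Dolly-v2-3b
pipeline through LangChain.
"""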
import os
import torch
import streamlit as st
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain_community.llms import HuggingFacePipeline
from langchain.text_splitter import CharacterTextSplitter
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from dotenv import load_dotenv

# Set Streamlit page configuration
st.set_page_config(page_title="Chat with Notes and AI", page_icon=":books:", layout="wide")

# Load environment variables
load_dotenv()

# Dolly-v2-3b model pipeline
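# Assumption: caching the pipeline keeps the 3B model from being reloaded on
# every Streamlit rerun; st.cache_resource requires streamlit >= 1.18.
@st.cache_resource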
def load_pipeline():
    model_name = "databricks/dolly-v2-3b"

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left", trust_remote_code=True)

    # Load model with an offload folder for disk storage of weights
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,  # bfloat16 on GPU, float32 on CPU
        device_map="auto",  # Automatically map the model to available devices (GPU if available)
        trust_remote_code=True,
        offload_folder="./offload_weights"  # Folder to store offloaded weights
    )
    # Return a text-generation pipeline; dtype and device placement were
    # already applied when the model was loaded, so they are not passed again
    return pipeline(
        task="text-generation",
        model=model,
        tokenizer=tokenizer,
        return_full_text=True  # Include the prompt in the raw pipeline output
    )

# Initialize the Dolly pipeline
generate_text = load_pipeline()

# Wrap the HuggingFace pipeline for LangChain
hf_pipeline = HuggingFacePipeline(pipeline=generate_text)

# Template for instruction-only prompts
prompt = PromptTemplate(
    input_variables=["instruction"],
    template="{instruction}"
)

# Template for prompts with context
prompt_with_context = PromptTemplate(
    input_variables=["instruction", "context"],
    template="{instruction}\n\nInput:\n{context}"
)

# Create LLM chains
llm_chain = LLMChain(llm=hf_pipeline, prompt=prompt)
llm_context_chain = LLMChain(llm=hf_pipeline, prompt=prompt_with_context)
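# Illustrative usage (not executed): each chain's invoke() returns a dict
# whose "text" key holds the generated answer, e.g.
#   llm_chain.invoke({"instruction": "What is FAISS?"})["text"]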

# Extract text from .txt files in a folder
def get_text_files_content(folder):
    text = ""
    for filename in sorted(os.listdir(folder)):  # sorted for a deterministic concatenation order
        if filename.endswith('.txt'):
            with open(os.path.join(folder, filename), 'r', encoding='utf-8') as file:
                text += file.read() + "\n"
    return text

# Split raw text into chunks
def get_chunks(raw_text):
    text_splitter = CharacterTextSplitter(
        separator="\n",
        chunk_size=1000,    # Reduced chunk size for faster processing
        chunk_overlap=200,  # Smaller overlap for efficiency
        length_function=len
    )
    return text_splitter.split_text(raw_text)

# Use a Hugging Face embeddings model and FAISS to create the vectorstore
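# Assumption: caching avoids re-embedding the same subject's chunks on every
# Streamlit rerun; safe to drop if this caching behavior is not wanted.
@st.cache_resource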
def get_vectorstore(chunks):
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={'device': 'cpu'}  # Run embeddings on CPU
    )
    return FAISS.from_texts(texts=chunks, embedding=embeddings)
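# Illustrative usage (not executed), mirroring how main() wires these helpers:
#   vs = get_vectorstore(get_chunks(raw_text))
#   docs = vs.similarity_search("some query", k=2)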

# Generate a response to a user query
def handle_question(question, vectorstore=None):
    if vectorstore:
        # Retrieve only the top chunks to keep inference fast
        documents = vectorstore.similarity_search(question, k=2)
        context = "\n".join(doc.page_content for doc in documents)
        # Limit context to 1000 characters to speed up model inference
        context = context[:1000]
        if context:
            return llm_context_chain.invoke({"instruction": question, "context": context})
    # Fall back to the instruction-only chain if no context is found
    return llm_chain.invoke({"instruction": question})
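# Illustrative result shape (not executed): invoke() echoes the inputs and adds
# the generated answer under "text", e.g.
#   {"instruction": "...", "context": "...", "text": "<answer>"}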

def main():
    st.title("Chat with Notes :books:")

    # Initialize session state
    if "vectorstore" not in st.session_state:
        st.session_state.vectorstore = None

    # Folder for subject data
    data_folder = "data"

    # Subject selection
    subjects = [
        "A Trumped World", "Agri Tax in Punjab", "Assad's Fall in Syria", "Elusive National Unity", "Europe and Trump 2.0",
        "Going Down with Democracy", "Indonesia's Pancasila Philosophy", "Pakistan in Choppy Waters",
        "Pakistan's Semiconductor Ambitions", "Preserving Pakistan's Cultural Heritage", "Tackling Informal Economy",
        "Technical Education in Pakistan", "The Case for Solidarity Levies", "The Decline of the Sole Superpower",
        "The Power of Big Oil", "Trump 2.0 and Pakistan's Emerging Foreign Policy", "Trump and the World 2.0",
        "Trump vs BRICS", "US-China Trade War", "War on Humanity", "Women's Suppression in Afghanistan"
    ]
    subject_folders = {subject: os.path.join(data_folder, subject.replace(' ', '_')) for subject in subjects}
    selected_subject = st.sidebar.selectbox("Select a Subject:", subjects)

    # Build the vectorstore for the selected subject
    subject_folder_path = subject_folders[selected_subject]
    raw_text = ""
    if os.path.exists(subject_folder_path):
        raw_text = get_text_files_content(subject_folder_path)
        if raw_text:
            st.session_state.vectorstore = get_vectorstore(get_chunks(raw_text))
        else:
            st.session_state.vectorstore = None  # Clear any stale vectorstore from a previous subject
            st.warning("No content found for the selected subject.")
    else:
        st.session_state.vectorstore = None  # Clear any stale vectorstore from a previous subject
        st.error(f"Folder not found for {selected_subject}.")

    # Display a preview of the notes
    if raw_text:
        st.subheader("Preview of Notes")
        st.text_area("Preview Content:", value=raw_text[:2000], height=300, disabled=True)  # Show a snippet of the notes

    # Chat interface
    st.subheader("Ask Your Question")
    question = st.text_input("Ask a question about your selected subject:")
    if question:
        if st.session_state.vectorstore:
            response = handle_question(question, st.session_state.vectorstore)
            st.subheader("Answer:")
            st.write(response.get("text", "No response found."))
        else:
            st.warning("Please load the content for the selected subject before asking a question.")

if __name__ == '__main__':
    main()