Deepakraj2006 committed
Commit f8407d3 · verified · 1 Parent(s): e05207c

Update worker.py

Files changed (1):
  worker.py +22 -13
worker.py CHANGED
@@ -1,11 +1,12 @@
 import os
 import torch
-from langchain.chains import RetrievalQA
+from langchain.chains import ConversationalRetrievalChain
 from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain_community.document_loaders import PyPDFLoader
 from langchain_text_splitters import RecursiveCharacterTextSplitter
 from langchain_community.vectorstores import Chroma
-from langchain_community.llms import HuggingFaceHub
+from langchain_community.llms import HuggingFacePipeline
+from transformers import pipeline
 
 # Check for GPU availability
 DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
@@ -13,11 +14,12 @@ DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
 # Global variables
 conversation_retrieval_chain = None
 chat_history = []
-llm_hub = None
+llm_pipeline = None
 embeddings = None
 
+
 def init_llm():
-    global llm_hub, embeddings
+    global llm_pipeline, embeddings
 
     # Ensure API key is set in Hugging Face Spaces
     hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
@@ -25,17 +27,19 @@ def init_llm():
         raise ValueError("HUGGINGFACEHUB_API_TOKEN is not set in environment variables.")
 
     model_id = "tiiuae/falcon-7b-instruct"
-    llm_hub = HuggingFaceHub(repo_id=model_id, model_kwargs={"temperature": 0.1, "max_new_tokens": 600, "max_length": 600})
+    hf_pipeline = pipeline("text-generation", model=model_id, device=DEVICE)
+    llm_pipeline = HuggingFacePipeline(pipeline=hf_pipeline)
 
     embeddings = HuggingFaceEmbeddings(
         model_name="sentence-transformers/all-MiniLM-L6-v2", model_kwargs={"device": DEVICE}
     )
 
+
 def process_document(document_path):
     global conversation_retrieval_chain
 
     # Ensure LLM and embeddings are initialized
-    if not llm_hub or not embeddings:
+    if not llm_pipeline or not embeddings:
         init_llm()
 
     loader = PyPDFLoader(document_path)
@@ -44,22 +48,27 @@ def process_document(document_path):
     text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
     texts = text_splitter.split_documents(documents)
 
-    db = Chroma.from_documents(texts, embedding=embeddings, persist_directory="./chroma_db")
+    # Load or create ChromaDB
+    persist_directory = "./chroma_db"
+    if os.path.exists(persist_directory):
+        db = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
+    else:
+        db = Chroma.from_documents(texts, embedding=embeddings, persist_directory=persist_directory)
 
-    conversation_retrieval_chain = RetrievalQA.from_chain_type(
-        llm=llm_hub,
-        chain_type="stuff",
-        retriever=db.as_retriever(search_type="mmr", search_kwargs={'k': 6, 'lambda_mult': 0.25}),
-        return_source_documents=False
+    retriever = db.as_retriever(search_type="similarity", search_kwargs={'k': 6})
+
+    conversation_retrieval_chain = ConversationalRetrievalChain.from_llm(
+        llm=llm_pipeline, retriever=retriever
     )
 
+
 def process_prompt(prompt):
     global conversation_retrieval_chain, chat_history
 
     if not conversation_retrieval_chain:
         return "No document has been processed yet. Please upload a PDF first."
 
-    output = conversation_retrieval_chain({"query": prompt, "chat_history": chat_history})
+    output = conversation_retrieval_chain({"question": prompt, "chat_history": chat_history})
     answer = output["answer"]
 
     chat_history.append((prompt, answer))
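
Note on the new init_llm(): the commit swaps the remote HuggingFaceHub endpoint for a locally loaded transformers pipeline, but the old generation settings (temperature, max_new_tokens) are dropped, so Falcon now generates with its defaults; hf_token is also still validated even though a local load only needs it for gated repos. A minimal sketch of carrying the limits over — assuming a transformers version whose pipeline() accepts a device string such as "cuda:0" and forwards generation kwargs:

import os
import torch
from transformers import pipeline
from langchain_community.llms import HuggingFacePipeline

# Sketch only, not part of the commit: reinstate the old generation limits.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
hf_pipeline = pipeline(
    "text-generation",
    model="tiiuae/falcon-7b-instruct",
    device=DEVICE,                      # assumption: a device string is accepted here
    do_sample=True,                     # temperature only takes effect when sampling
    temperature=0.1,                    # values previously passed via model_kwargs
    max_new_tokens=600,
    token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),  # only needed for gated models
)
llm_pipeline = HuggingFacePipeline(pipeline=hf_pipeline)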
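
Note on the load-or-create branch: when ./chroma_db already exists, the new code loads the persisted collection and never indexes the freshly split texts, so a second uploaded PDF would not reach the retriever. One hedged way to reuse the store and still index the new chunks, using the stock add_documents() vector-store method (texts and embeddings as defined in worker.py):

import os
from langchain_community.vectorstores import Chroma

persist_directory = "./chroma_db"
if os.path.exists(persist_directory):
    # Reuse the persisted collection, but still index the new document's chunks.
    db = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
    db.add_documents(texts)
else:
    db = Chroma.from_documents(texts, embedding=embeddings, persist_directory=persist_directory)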
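
Usage sketch: ConversationalRetrievalChain expects the input keys "question" and "chat_history" and returns its reply under "answer", which is why process_prompt now sends {"question": ...} instead of {"query": ...}. A hypothetical driver, assuming the tail of process_prompt (outside this hunk) returns answer:

import worker

worker.process_document("sample.pdf")   # hypothetical path to an uploaded PDF
reply = worker.process_prompt("What is this document about?")
print(reply)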