Pamudu13 committed
Commit c005795 · verified · 1 parent: 4707b28

Update app.py

Files changed (1): app.py (+10 −4)
app.py CHANGED
@@ -66,23 +66,28 @@ def create_db(splits):
     vectordb = FAISS.from_documents(splits, embeddings)
     return vectordb
 
-def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
-    """Initialize the LLM chain"""
+def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, api_token):
+    """Initialize the LLM chain with a HuggingFace model"""
+    # Use valid Hugging Face parameters. `max_length` might be the correct field instead of `max_new_tokens`
     llm = HuggingFaceEndpoint(
         repo_id=llm_model,
         huggingfacehub_api_token=api_token,
         temperature=temperature,
-        max_new_tokens=max_tokens,
-        top_k=top_k,
+        max_length=max_tokens,  # Adjusted from max_new_tokens to max_length
+        # Remove top_k as it may not be valid or handled differently
     )
 
+    # Set up memory for conversation
     memory = ConversationBufferMemory(
         memory_key="chat_history",
         output_key='answer',
         return_messages=True
     )
 
+    # Ensure vector_db is used as a retriever
     retriever = vector_db.as_retriever()
+
+    # Initialize ConversationalRetrievalChain using LLM and the retriever
     qa_chain = ConversationalRetrievalChain.from_llm(
         llm,
         retriever=retriever,
@@ -93,6 +98,7 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
     )
     return qa_chain
 
+
 def format_chat_history(message, chat_history):
     """Format chat history for the LLM"""
     formatted_chat_history = []
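
For context, a minimal usage sketch of the new signature: after this commit, `api_token` is an explicit argument rather than a module-level name, and `max_tokens` is forwarded as `max_length`. The model repo id, the `HF_TOKEN` environment variable, and the pre-chunked `splits` below are illustrative assumptions, not part of the commit; `create_db` and `format_chat_history` are the helpers shown in the diff.

import os

vector_db = create_db(splits)  # FAISS index built from pre-chunked documents (assumed)
qa_chain = initialize_llmchain(
    llm_model="mistralai/Mistral-7B-Instruct-v0.2",  # hypothetical repo id
    temperature=0.5,
    max_tokens=1024,       # forwarded as max_length after this commit
    top_k=3,               # still accepted, but unused after this commit
    vector_db=vector_db,
    api_token=os.environ["HF_TOKEN"],  # token is now passed in explicitly (env var name assumed)
)

response = qa_chain.invoke({
    "question": "What does this document say about X?",
    "chat_history": [],    # as produced by format_chat_history
})
print(response["answer"])  # output_key='answer' is set in the memory above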