Pamudu13 committed (verified)
Commit c062c17
Parent(s): c005795

Update app.py

Files changed (1): app.py (+21 -18)
app.py CHANGED
@@ -66,37 +66,40 @@ def create_db(splits):
     vectordb = FAISS.from_documents(splits, embeddings)
     return vectordb
 
-def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, api_token):
-    """Initialize the LLM chain with a HuggingFace model"""
-    # Use valid Hugging Face parameters. `max_length` might be the correct field instead of `max_new_tokens`
-    llm = HuggingFaceEndpoint(
-        repo_id=llm_model,
-        huggingfacehub_api_token=api_token,
-        temperature=temperature,
-        max_length=max_tokens,  # Adjusted from max_new_tokens to max_length
-        # Remove top_k as it may not be valid or handled differently
-    )
-
-    # Set up memory for conversation
+def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
+    if llm_model == "meta-llama/Meta-Llama-3-8B-Instruct":
+        llm = HuggingFaceEndpoint(
+            repo_id=llm_model,
+            huggingfacehub_api_token = api_token,
+            temperature = temperature,
+            max_new_tokens = max_tokens,
+            top_k = top_k,
+        )
+    else:
+        llm = HuggingFaceEndpoint(
+            huggingfacehub_api_token = api_token,
+            repo_id=llm_model,
+            temperature = temperature,
+            max_new_tokens = max_tokens,
+            top_k = top_k,
+        )
+
     memory = ConversationBufferMemory(
         memory_key="chat_history",
         output_key='answer',
         return_messages=True
     )
 
-    # Ensure vector_db is used as a retriever
-    retriever = vector_db.as_retriever()
-
-    # Initialize ConversationalRetrievalChain using LLM and the retriever
+    retriever=vector_db.as_retriever()
     qa_chain = ConversationalRetrievalChain.from_llm(
         llm,
         retriever=retriever,
-        chain_type="stuff",
+        chain_type="stuff",
         memory=memory,
         return_source_documents=True,
         verbose=False,
     )
-    return qa_chain
+    return qa_chain
 
 
 def format_chat_history(message, chat_history):
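
Note on the new signature: the commit drops api_token from the parameter list in favor of progress=gr.Progress(), yet both HuggingFaceEndpoint calls still reference api_token, so the function now depends on api_token being defined at module scope elsewhere in app.py. The two branches also pass identical keyword arguments (only their order differs), so the if/else makes no difference at runtime. A minimal call-site sketch under the module-scope assumption; the environment-variable name and parameter values below are illustrative, not part of the commit:

import os

# Assumption: app.py defines api_token at module scope so the
# HuggingFaceEndpoint calls inside initialize_llmchain can see it.
api_token = os.environ.get("HF_TOKEN")  # hypothetical env var name

# Hypothetical call site: build the FAISS index from document splits,
# then construct the conversational chain.
vector_db = create_db(splits)  # `splits` produced earlier in app.py
qa_chain = initialize_llmchain(
    llm_model="meta-llama/Meta-Llama-3-8B-Instruct",  # takes the first branch
    temperature=0.7,      # illustrative value
    max_tokens=1024,      # forwarded as max_new_tokens
    top_k=3,              # illustrative value
    vector_db=vector_db,
)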