arjunanand13 commited on
Commit
715be0e
·
verified ·
1 Parent(s): 90d7570

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -18
app.py CHANGED
@@ -140,23 +140,18 @@ def evaluate_rag_pipeline(qa_chain, dataset):
140
  return avg_results
141
 
142
  # Initialize langchain LLM chain
143
- def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
144
- if llm_model == "meta-llama/Meta-Llama-3-8B-Instruct":
145
- llm = HuggingFaceEndpoint(
146
- repo_id=llm_model,
147
- huggingfacehub_api_token=api_token,
148
- temperature=temperature,
149
- max_new_tokens=max_tokens,
150
- top_k=top_k,
151
- )
152
- else:
153
- llm = HuggingFaceEndpoint(
154
- huggingfacehub_api_token=api_token,
155
- repo_id=llm_model,
156
- temperature=temperature,
157
- max_new_tokens=max_tokens,
158
- top_k=top_k,
159
- )
160
 
161
  memory = ConversationBufferMemory(
162
  memory_key="chat_history",
@@ -173,7 +168,7 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, pr
173
  return_source_documents=True,
174
  verbose=False,
175
  )
176
- return qa_chain
177
 
178
  def initialize_database(list_file_obj, splitting_strategy, db_choice, progress=gr.Progress()):
179
  list_file_path = [x.name for x in list_file_obj if x is not None]
 
140
  return avg_results
141
 
142
  # Initialize langchain LLM chain
143
+ def initialize_llmchain(llm_choice, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
144
+ # Get the full model name from the index
145
+ llm_model = list_llm[llm_choice]
146
+
147
+ llm = HuggingFaceEndpoint(
148
+ repo_id=llm_model,
149
+ huggingfacehub_api_token=api_token,
150
+ temperature=temperature,
151
+ max_new_tokens=max_tokens,
152
+ top_k=top_k,
153
+ model=llm_model # Add model parameter
154
+ )
 
 
 
 
 
155
 
156
  memory = ConversationBufferMemory(
157
  memory_key="chat_history",
 
168
  return_source_documents=True,
169
  verbose=False,
170
  )
171
+ return qa_chain, "LLM initialized successfully!"
172
 
173
  def initialize_database(list_file_obj, splitting_strategy, db_choice, progress=gr.Progress()):
174
  list_file_path = [x.name for x in list_file_obj if x is not None]