DHEIVER commited on
Commit
1acc22b
·
verified ·
1 Parent(s): 15f3912

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -11
app.py CHANGED
@@ -41,23 +41,28 @@ def create_db(splits):
41
  return vectordb
42
 
43
 
 
 
 
44
  # Initialize langchain LLM chain
45
  def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
46
  if llm_model == "meta-llama/Meta-Llama-3-8B-Instruct":
47
  llm = HuggingFaceEndpoint(
48
  repo_id=llm_model,
49
- huggingfacehub_api_token = api_token,
50
- temperature = temperature,
51
- max_new_tokens = max_tokens,
52
- top_k = top_k,
 
53
  )
54
  else:
55
  llm = HuggingFaceEndpoint(
56
- huggingfacehub_api_token = api_token,
57
- repo_id=llm_model,
58
- temperature = temperature,
59
- max_new_tokens = max_tokens,
60
- top_k = top_k,
 
61
  )
62
 
63
  memory = ConversationBufferMemory(
@@ -66,7 +71,7 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, pr
66
  return_messages=True
67
  )
68
 
69
- retriever=vector_db.as_retriever()
70
  qa_chain = ConversationalRetrievalChain.from_llm(
71
  llm,
72
  retriever=retriever,
@@ -76,7 +81,7 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, pr
76
  verbose=False,
77
  )
78
  return qa_chain
79
-
80
  # Initialize database
81
  def initialize_database(list_file_obj, progress=gr.Progress()):
82
  # Create a list of documents (when valid)
 
41
  return vectordb
42
 
43
 
44
+ # Initialize langchain LLM chain
45
+ from langchain_community.llms import HuggingFaceEndpoint
46
+
47
  # Initialize langchain LLM chain
48
  def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
49
  if llm_model == "meta-llama/Meta-Llama-3-8B-Instruct":
50
  llm = HuggingFaceEndpoint(
51
  repo_id=llm_model,
52
+ huggingfacehub_api_token=api_token,
53
+ temperature=temperature,
54
+ max_new_tokens=max_tokens,
55
+ top_k=top_k,
56
+ task="text-generation" # Explicitly specify the task type
57
  )
58
  else:
59
  llm = HuggingFaceEndpoint(
60
+ huggingfacehub_api_token=api_token,
61
+ repo_id=llm_model,
62
+ temperature=temperature,
63
+ max_new_tokens=max_tokens,
64
+ top_k=top_k,
65
+ task="text-generation" # Explicitly specify the task type
66
  )
67
 
68
  memory = ConversationBufferMemory(
 
71
  return_messages=True
72
  )
73
 
74
+ retriever = vector_db.as_retriever()
75
  qa_chain = ConversationalRetrievalChain.from_llm(
76
  llm,
77
  retriever=retriever,
 
81
  verbose=False,
82
  )
83
  return qa_chain
84
+
85
  # Initialize database
86
  def initialize_database(list_file_obj, progress=gr.Progress()):
87
  # Create a list of documents (when valid)