Pamudu13 committed · verified
Commit 6e8e412 · Parent(s): aaf779e

Update app.py

Files changed (1)
  1. app.py +21 -31
app.py CHANGED
@@ -67,40 +67,34 @@ def create_db(splits):
     return vectordb
 
 def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
-    if llm_model == "meta-llama/Meta-Llama-3-8B-Instruct":
-        llm = HuggingFaceEndpoint(
-            repo_id=llm_model,
-            huggingfacehub_api_token=api_token,
-            temperature=temperature,
-            max_new_tokens=max_tokens,
-            top_k=top_k,
-        )
-    else:
-        llm = HuggingFaceEndpoint(
-            huggingfacehub_api_token=api_token,
-            repo_id=llm_model,
-            temperature=temperature,
-            max_new_tokens=max_tokens,
-            top_k=top_k,
-        )
-
+    """Initialize the LLM chain with correct parameter names"""
+    llm = HuggingFaceEndpoint(
+        endpoint_url="https://api-inference.huggingface.co/models/" + llm_model,
+        task="text-generation",
+        model_kwargs={
+            "temperature": float(temperature),
+            "max_length": int(max_tokens),
+            "top_k": int(top_k)
+        },
+        huggingfacehub_api_token=api_token
+    )
+
     memory = ConversationBufferMemory(
         memory_key="chat_history",
         output_key='answer',
         return_messages=True
     )
 
-    retriever=vector_db.as_retriever()
+    retriever = vector_db.as_retriever()
     qa_chain = ConversationalRetrievalChain.from_llm(
         llm,
         retriever=retriever,
-        chain_type="stuff",
+        chain_type="stuff",
         memory=memory,
         return_source_documents=True,
         verbose=False,
     )
-    return qa_chain
-
+    return qa_chain
 
 def format_chat_history(message, chat_history):
     """Format chat history for the LLM"""
@@ -154,31 +148,27 @@ def init_llm():
     if vector_db is None:
         return jsonify({'error': 'Please upload PDFs first'}), 400
 
-    # Get parameters from the incoming request
     data = request.json
-    model_name = data.get('model', 'llama')  # Default to 'llama' if not provided
-    temperature = data.get('temperature', 0.5)
-    max_tokens = data.get('max_tokens', 4096)
-    top_k = data.get('top_k', 3)
+    model_name = data.get('model', 'llama')  # Default to 'llama'
+    temperature = float(data.get('temperature', 0.5))
+    max_tokens = int(data.get('max_tokens', 4096))
+    top_k = int(data.get('top_k', 3))
 
-    # Ensure the model name is valid
     if model_name not in LLM_MODELS:
         return jsonify({'error': 'Invalid model name'}), 400
 
     try:
-        # Initialize the LLM chain with the specified parameters and the vector_db
         qa_chain = initialize_llmchain(
             llm_model=LLM_MODELS[model_name],
             temperature=temperature,
             max_tokens=max_tokens,
             top_k=top_k,
-            vector_db=vector_db  # Pass vector_db to the function
+            vector_db=vector_db
         )
         return jsonify({'message': 'LLM initialized successfully'}), 200
     except Exception as e:
         return jsonify({'error': str(e)}), 500
 
-
 @app.route('/chat', methods=['POST'])
 def chat():
     """Handle chat interactions"""
@@ -275,7 +265,7 @@ def finish_upload():
 
     if not current_upload['filename']:
         return jsonify({'error': 'No upload in progress'}), 400
-
+
     try:
         # Create temp directory if it doesn't exist
         os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)