Pamudu13 commited on
Commit
aaf779e
·
verified ·
1 Parent(s): cfcc518

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -6
app.py CHANGED
@@ -154,27 +154,31 @@ def init_llm():
154
  if vector_db is None:
155
  return jsonify({'error': 'Please upload PDFs first'}), 400
156
 
 
157
  data = request.json
158
- model_name = data.get('model', 'llama') # default to llama
159
  temperature = data.get('temperature', 0.5)
160
  max_tokens = data.get('max_tokens', 4096)
161
  top_k = data.get('top_k', 3)
162
 
 
163
  if model_name not in LLM_MODELS:
164
  return jsonify({'error': 'Invalid model name'}), 400
165
 
166
  try:
 
167
  qa_chain = initialize_llmchain(
168
- LLM_MODELS[model_name],
169
- temperature,
170
- max_tokens,
171
- top_k,
172
- vector_db
173
  )
174
  return jsonify({'message': 'LLM initialized successfully'}), 200
175
  except Exception as e:
176
  return jsonify({'error': str(e)}), 500
177
 
 
178
  @app.route('/chat', methods=['POST'])
179
  def chat():
180
  """Handle chat interactions"""
 
154
  if vector_db is None:
155
  return jsonify({'error': 'Please upload PDFs first'}), 400
156
 
157
+ # Get parameters from the incoming request
158
  data = request.json
159
+ model_name = data.get('model', 'llama') # Default to 'llama' if not provided
160
  temperature = data.get('temperature', 0.5)
161
  max_tokens = data.get('max_tokens', 4096)
162
  top_k = data.get('top_k', 3)
163
 
164
+ # Ensure the model name is valid
165
  if model_name not in LLM_MODELS:
166
  return jsonify({'error': 'Invalid model name'}), 400
167
 
168
  try:
169
+ # Initialize the LLM chain with the specified parameters and the vector_db
170
  qa_chain = initialize_llmchain(
171
+ llm_model=LLM_MODELS[model_name],
172
+ temperature=temperature,
173
+ max_tokens=max_tokens,
174
+ top_k=top_k,
175
+ vector_db=vector_db # Pass vector_db to the function
176
  )
177
  return jsonify({'message': 'LLM initialized successfully'}), 200
178
  except Exception as e:
179
  return jsonify({'error': str(e)}), 500
180
 
181
+
182
  @app.route('/chat', methods=['POST'])
183
  def chat():
184
  """Handle chat interactions"""