Pijush2023 committed on
Commit f073604 · verified · 1 Parent(s): 0c05143

Update app.py

Files changed (1)
  1. app.py  +36 -3
app.py CHANGED
@@ -99,12 +99,20 @@ def initialize_phi_model():
 
 def initialize_gpt_model():
     return ChatOpenAI(api_key=os.environ['OPENAI_API_KEY'], temperature=0, model='gpt-4o')
+
+def initialize_gpt4o_mini_model():
+    return ChatOpenAI(api_key=os.environ['OPENAI_API_KEY'], temperature=0, model='gpt-4o-mini')
+
+
+
 
 
 
 # Initialize all models
 phi_pipe = initialize_phi_model()
 gpt_model = initialize_gpt_model()
+gpt4o_mini_model = initialize_gpt4o_mini_model()
+
 
 
 
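(Aside, not part of the commit: the two GPT initializers above differ only in the model name, so a single parameterized factory could serve both. A minimal sketch, assuming ChatOpenAI is imported from langchain_openai as in current LangChain releases; app.py's actual import may differ.)

    import os
    from langchain_openai import ChatOpenAI

    def initialize_openai_model(model: str) -> ChatOpenAI:
        # Deterministic (temperature=0) chat model, as in both initializers.
        return ChatOpenAI(api_key=os.environ['OPENAI_API_KEY'], temperature=0, model=model)

    gpt_model = initialize_openai_model('gpt-4o')
    gpt4o_mini_model = initialize_openai_model('gpt-4o-mini')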
 
@@ -140,7 +148,7 @@ vectorstore = PineconeVectorStore(index_name=index_name, embedding=embeddings)
 retriever = vectorstore.as_retriever(search_kwargs={'k': 5})
 
 chat_model = ChatOpenAI(api_key=os.environ['OPENAI_API_KEY'], temperature=0, model='gpt-4o')
-
+chat_model1 = ChatOpenAI(api_key=os.environ['OPENAI_API_KEY'], temperature=0, model='gpt-4o-mini')
 conversational_memory = ConversationBufferWindowMemory(
     memory_key='chat_history',
     k=10,
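(For context: ConversationBufferWindowMemory with k=10 keeps only the ten most recent exchanges under 'chat_history'. A small standalone illustration, assuming the classic langchain.memory import path:)

    from langchain.memory import ConversationBufferWindowMemory

    memory = ConversationBufferWindowMemory(memory_key='chat_history', k=10)
    memory.save_context({"input": "hi"}, {"output": "hello"})
    # Only the 10 most recent turns are returned once the window fills.
    print(memory.load_memory_variables({}))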
@@ -353,7 +361,9 @@ def generate_bot_response(history, choice, retrieval_mode, model_choice):
         return
 
     # Select the model
-    selected_model = chat_model if model_choice == "LM-1" else phi_pipe
+    # selected_model = chat_model if model_choice == "LM-1" else phi_pipe
+    selected_model = chat_model if model_choice == "LM-1" else (chat_model1 if model_choice == "LM-3" else phi_pipe)
+
 
     response, addresses = generate_answer(history[-1][0], choice, retrieval_mode, selected_model)
     history[-1][1] = ""
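(The same nested ternary is repeated in the second branch below, at new line 439. If more model choices are added, a lookup table keeps both call sites in sync; a sketch, assuming the chat_model, chat_model1, and phi_pipe objects defined earlier in app.py:)

    MODEL_REGISTRY = {
        "LM-1": chat_model,   # gpt-4o
        "LM-3": chat_model1,  # gpt-4o-mini
    }

    def select_model(model_choice):
        # Unrecognized choices fall back to the local Phi pipeline,
        # matching the ternary's final else.
        return MODEL_REGISTRY.get(model_choice, phi_pipe)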
@@ -425,7 +435,9 @@ def generate_bot_response(history, choice, retrieval_mode, model_choice):
         return
 
     # Select the model
-    selected_model = chat_model if model_choice == "LM-1" else phi_pipe
+    # selected_model = chat_model if model_choice == "LM-1" else phi_pipe
+    selected_model = chat_model if model_choice == "LM-1" else (chat_model1 if model_choice == "LM-3" else phi_pipe)
+
 
     response, addresses = generate_answer(history[-1][0], choice, retrieval_mode, selected_model)
     history[-1][1] = ""
@@ -528,7 +540,28 @@ def generate_answer(message, choice, retrieval_mode, selected_model):
         response = qa_chain({"query": message})
         logging.debug(f"LM-1 response: {response}")
         return response['result'], extract_addresses(response['result'])
+
+    elif selected_model == chat_model1:
+        logging.debug("Selected model: LM-3")
+        retriever = gpt_retriever
+        context = retriever.get_relevant_documents(message)
+        logging.debug(f"Retrieved context: {context}")
 
+        prompt = prompt_template.format(context=context, question=message)
+        logging.debug(f"Generated prompt: {prompt}")
+
+        qa_chain = RetrievalQA.from_chain_type(
+            llm=chat_model1,
+            chain_type="stuff",
+            retriever=retriever,
+            chain_type_kwargs={"prompt": prompt_template}
+        )
+        response = qa_chain({"query": message})
+        logging.debug(f"LM-3 response: {response}")
+        return response['result'], extract_addresses(response['result'])
+
+
+
     elif selected_model == phi_pipe:
         logging.debug("Selected model: LM-2")
         retriever = phi_retriever
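(Note on the new LM-3 branch: the manually retrieved context and the formatted prompt feed only the debug logs. RetrievalQA runs the retriever itself and applies prompt_template via chain_type_kwargs, so under that assumption the branch reduces to the sketch below:)

    elif selected_model == chat_model1:
        logging.debug("Selected model: LM-3")
        qa_chain = RetrievalQA.from_chain_type(
            llm=chat_model1,
            chain_type="stuff",
            retriever=gpt_retriever,
            chain_type_kwargs={"prompt": prompt_template},
        )
        response = qa_chain({"query": message})
        logging.debug(f"LM-3 response: {response}")
        return response['result'], extract_addresses(response['result'])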
 