Gourisankar Padihary commited on
Commit
2889c96
·
1 Parent(s): 5661370

Capability to modify the llm through UI

Browse files
app.py CHANGED
@@ -4,7 +4,8 @@ import threading
4
  import time
5
  from generator.compute_metrics import get_attributes_text
6
  from generator.generate_metrics import generate_metrics, retrieve_and_generate_response
7
- from config import AppConfig, ConfigConstants
 
8
 
9
  def launch_gradio(config : AppConfig):
10
  """
@@ -80,17 +81,50 @@ def launch_gradio(config : AppConfig):
80
  logging.error(f"Error computing metrics: {e}")
81
  return f"An error occurred: {e}", ""
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  # Define Gradio Blocks layout
84
  with gr.Blocks() as interface:
85
  interface.title = "Real Time RAG Pipeline Q&A"
86
  gr.Markdown("### Real Time RAG Pipeline Q&A") # Heading
87
 
 
 
 
 
 
 
 
88
  # Section to display LLM names
89
  with gr.Row():
90
  model_info = f"Embedding Model: {ConfigConstants.EMBEDDING_MODEL_NAME}\n"
91
  model_info += f"Generation LLM: {config.gen_llm.name if hasattr(config.gen_llm, 'name') else 'Unknown'}\n"
92
  model_info += f"Validation LLM: {config.val_llm.name if hasattr(config.val_llm, 'name') else 'Unknown'}\n"
93
- gr.Textbox(value=model_info, label="Model Information", interactive=False) # Read-only textbox
94
 
95
  # State to store response and source documents
96
  state = gr.State(value={"query": "","response": "", "source_docs": {}})
@@ -122,7 +156,19 @@ def launch_gradio(config : AppConfig):
122
  inputs=[state],
123
  outputs=[attr_output, metrics_output]
124
  )
 
 
 
 
 
 
125
 
 
 
 
 
 
 
126
  # Section to display logs
127
  with gr.Row():
128
  start_log_button = gr.Button("Start Log Update", elem_id="start_btn") # Button to start log updates
 
4
  import time
5
  from generator.compute_metrics import get_attributes_text
6
  from generator.generate_metrics import generate_metrics, retrieve_and_generate_response
7
+ from config import AppConfig, ConfigConstants
8
+ from generator.initialize_llm import initialize_generation_llm, initialize_validation_llm
9
 
10
  def launch_gradio(config : AppConfig):
11
  """
 
81
  logging.error(f"Error computing metrics: {e}")
82
  return f"An error occurred: {e}", ""
83
 
84
+ def reinitialize_gen_llm(gen_llm_name):
85
+ """Reinitialize the generation LLM and return updated model info."""
86
+ if gen_llm_name.strip(): # Only update if input is not empty
87
+ config.gen_llm = initialize_generation_llm(gen_llm_name)
88
+
89
+ # Return updated model information
90
+ updated_model_info = (
91
+ f"Embedding Model: {ConfigConstants.EMBEDDING_MODEL_NAME}\n"
92
+ f"Generation LLM: {config.gen_llm.name if hasattr(config.gen_llm, 'name') else 'Unknown'}\n"
93
+ f"Validation LLM: {config.val_llm.name if hasattr(config.val_llm, 'name') else 'Unknown'}\n"
94
+ )
95
+ return updated_model_info
96
+
97
+ def reinitialize_val_llm(val_llm_name):
98
+ """Reinitialize the generation LLM and return updated model info."""
99
+ if val_llm_name.strip(): # Only update if input is not empty
100
+ config.val_llm = initialize_validation_llm(val_llm_name)
101
+
102
+ # Return updated model information
103
+ updated_model_info = (
104
+ f"Embedding Model: {ConfigConstants.EMBEDDING_MODEL_NAME}\n"
105
+ f"Generation LLM: {config.gen_llm.name if hasattr(config.gen_llm, 'name') else 'Unknown'}\n"
106
+ f"Validation LLM: {config.val_llm.name if hasattr(config.val_llm, 'name') else 'Unknown'}\n"
107
+ )
108
+ return updated_model_info
109
+
110
  # Define Gradio Blocks layout
111
  with gr.Blocks() as interface:
112
  interface.title = "Real Time RAG Pipeline Q&A"
113
  gr.Markdown("### Real Time RAG Pipeline Q&A") # Heading
114
 
115
+ # Textbox for new generation LLM name
116
+ with gr.Row():
117
+ new_gen_llm_input = gr.Textbox(label="New Generation LLM Name", placeholder="Enter LLM name to update")
118
+ update_gen_llm_button = gr.Button("Update Generation LLM")
119
+ new_val_llm_input = gr.Textbox(label="New Validation LLM Name", placeholder="Enter LLM name to update")
120
+ update_val_llm_button = gr.Button("Update Validation LLM")
121
+
122
  # Section to display LLM names
123
  with gr.Row():
124
  model_info = f"Embedding Model: {ConfigConstants.EMBEDDING_MODEL_NAME}\n"
125
  model_info += f"Generation LLM: {config.gen_llm.name if hasattr(config.gen_llm, 'name') else 'Unknown'}\n"
126
  model_info += f"Validation LLM: {config.val_llm.name if hasattr(config.val_llm, 'name') else 'Unknown'}\n"
127
+ model_info_display = gr.Textbox(value=model_info, label="Model Information", interactive=False) # Read-only textbox
128
 
129
  # State to store response and source documents
130
  state = gr.State(value={"query": "","response": "", "source_docs": {}})
 
156
  inputs=[state],
157
  outputs=[attr_output, metrics_output]
158
  )
159
+
160
+ update_gen_llm_button.click(
161
+ fn=reinitialize_gen_llm,
162
+ inputs=[new_gen_llm_input],
163
+ outputs=[model_info_display] # Update the displayed model info
164
+ )
165
 
166
+ update_val_llm_button.click(
167
+ fn=reinitialize_val_llm,
168
+ inputs=[new_val_llm_input],
169
+ outputs=[model_info_display] # Update the displayed model info
170
+ )
171
+
172
  # Section to display logs
173
  with gr.Row():
174
  start_log_button = gr.Button("Start Log Update", elem_id="start_btn") # Button to start log updates
config.py CHANGED
@@ -1,7 +1,7 @@
1
 
2
  class ConfigConstants:
3
  # Constants related to datasets and models
4
- DATA_SET_NAMES = ['covidqa', 'cuad', 'delucionqa', 'emanual', 'expertqa', 'finqa', 'hagrid', 'hotpotqa', 'msmarco', 'pubmedqa', 'tatqa', 'techqa']
5
  EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-MiniLM-L3-v2"
6
  RE_RANKER_MODEL_NAME = 'cross-encoder/ms-marco-electra-base'
7
  GENERATION_MODEL_NAME = 'mixtral-8x7b-32768'
 
1
 
2
  class ConfigConstants:
3
  # Constants related to datasets and models
4
+ DATA_SET_NAMES = ['covidqa', 'cuad']#, 'delucionqa', 'emanual', 'expertqa', 'finqa', 'hagrid', 'hotpotqa', 'msmarco', 'pubmedqa', 'tatqa', 'techqa']
5
  EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-MiniLM-L3-v2"
6
  RE_RANKER_MODEL_NAME = 'cross-encoder/ms-marco-electra-base'
7
  GENERATION_MODEL_NAME = 'mixtral-8x7b-32768'
generator/initialize_llm.py CHANGED
@@ -2,18 +2,22 @@ import logging
2
  import os
3
  from langchain_groq import ChatGroq
4
 
5
- from config import ConfigConstants
6
-
7
- def initialize_generation_llm():
8
  os.environ["GROQ_API_KEY"] = ""
9
- model_name = ConfigConstants.GENERATION_MODEL_NAME
 
10
  llm = ChatGroq(model=model_name, temperature=0.7)
 
11
  logging.info(f'Generation LLM {model_name} initialized')
 
12
  return llm
13
 
14
- def initialize_validation_llm():
15
  os.environ["GROQ_API_KEY"] = ""
16
- model_name = ConfigConstants.VALIDATION_MODEL_NAME
 
17
  llm = ChatGroq(model=model_name, temperature=0.7)
 
18
  logging.info(f'Validation LLM {model_name} initialized')
 
19
  return llm
 
2
  import os
3
  from langchain_groq import ChatGroq
4
 
5
+ def initialize_generation_llm(input_model_name):
 
 
6
  os.environ["GROQ_API_KEY"] = ""
7
+
8
+ model_name = input_model_name
9
  llm = ChatGroq(model=model_name, temperature=0.7)
10
+ llm.name = model_name
11
  logging.info(f'Generation LLM {model_name} initialized')
12
+
13
  return llm
14
 
15
+ def initialize_validation_llm(input_model_name):
16
  os.environ["GROQ_API_KEY"] = ""
17
+
18
+ model_name = input_model_name
19
  llm = ChatGroq(model=model_name, temperature=0.7)
20
+ llm.name = model_name
21
  logging.info(f'Validation LLM {model_name} initialized')
22
+
23
  return llm
main.py CHANGED
@@ -44,10 +44,10 @@ def main():
44
  logging.info("Documents embedded")
45
 
46
  # Initialize the Generation LLM
47
- gen_llm = initialize_generation_llm()
48
 
49
  # Initialize the Validation LLM
50
- val_llm = initialize_validation_llm()
51
 
52
  #Compute RMSE and AUC-ROC for entire dataset
53
  #Enable below code for calculation
@@ -55,7 +55,7 @@ def main():
55
  #compute_rmse_auc_roc_metrics(gen_llm, val_llm, datasets[data_set_name], vector_store, 10)
56
 
57
  # Launch the Gradio app
58
- config = AppConfig(vector_store= vector_store, gen_llm= gen_llm, val_llm= val_llm)
59
  launch_gradio(config)
60
 
61
  logging.info("Finished!!!")
 
44
  logging.info("Documents embedded")
45
 
46
  # Initialize the Generation LLM
47
+ gen_llm = initialize_generation_llm(ConfigConstants.GENERATION_MODEL_NAME)
48
 
49
  # Initialize the Validation LLM
50
+ val_llm = initialize_validation_llm(ConfigConstants.VALIDATION_MODEL_NAME)
51
 
52
  #Compute RMSE and AUC-ROC for entire dataset
53
  #Enable below code for calculation
 
55
  #compute_rmse_auc_roc_metrics(gen_llm, val_llm, datasets[data_set_name], vector_store, 10)
56
 
57
  # Launch the Gradio app
58
+ config = AppConfig(vector_store= vector_store, gen_llm = gen_llm, val_llm = val_llm)
59
  launch_gradio(config)
60
 
61
  logging.info("Finished!!!")
retriever/retrieve_documents.py CHANGED
@@ -8,7 +8,7 @@ def retrieve_top_k_documents(vector_store, query, top_k=5):
8
  documents = vector_store.similarity_search(query, k=top_k)
9
  logging.info(f"Top {top_k} documents reterived for query")
10
 
11
- documents = rerank_documents(query, documents)
12
 
13
  return documents
14
 
 
8
  documents = vector_store.similarity_search(query, k=top_k)
9
  logging.info(f"Top {top_k} documents reterived for query")
10
 
11
+ #documents = rerank_documents(query, documents)
12
 
13
  return documents
14