AbenzaFran committed on
Commit
bfec1c8
·
verified ·
1 Parent(s): b2205c7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -12
app.py CHANGED
@@ -4,35 +4,60 @@ from dotenv import load_dotenv
4
  load_dotenv()
5
 
6
  from langchain.agents.openai_assistant import OpenAIAssistantRunnable
7
- from langchain.agents import AgentExecutor
8
-
9
  from langchain.schema import HumanMessage, AIMessage
10
 
11
- import gradio
12
 
 
13
# Runtime configuration, read from the environment (populated by load_dotenv above).
api_key = os.getenv('OPENAI_API_KEY')
extractor_agent = os.getenv('ASSISTANT_ID_SOLUTION_SPECIFIER_A')

# A single assistant runnable, built once at import time and shared by predict().
extractor_llm = OpenAIAssistantRunnable(
    assistant_id=extractor_agent,
    api_key=api_key,
    as_agent=True,
)
17
-
18
-
 
 
 
 
 
 
 
19
def remove_citation(text):
    """Swap every 【number†text】 citation marker in *text* for the 📚 emoji."""
    # A single regex pass replaces all occurrences at once.
    return re.sub(r"【\d+†\w+】", "📚", text)
24
 
25
def predict(message, history):
    """Send *message* to the extractor assistant and return its reply with
    citation markers replaced.

    `history` is Gradio-style: a list of (human, ai) string pairs.
    """
    # Rebuild the conversation as LangChain messages, one Human/AI pair per turn.
    history_langchain_format = []
    for user_turn, bot_turn in history:
        history_langchain_format.append(HumanMessage(content=user_turn))
        history_langchain_format.append(AIMessage(content=bot_turn))
    history_langchain_format.append(HumanMessage(content=message))

    # NOTE(review): only `message` is forwarded to the assistant; the rebuilt
    # history above is never sent — confirm whether that is intentional.
    response = extractor_llm.invoke({"content": message})
    raw_output = response.return_values["output"]
    return remove_citation(raw_output)
35
 
36
# Build and publish the chat UI for the single configured assistant.
ui_options = {
    "title": "Solution Specifier A",
    "description": "testing for the time being",
}
chat = gradio.ChatInterface(predict, **ui_options)
chat.launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
4
  load_dotenv()
5
 
6
  from langchain.agents.openai_assistant import OpenAIAssistantRunnable
 
 
7
  from langchain.schema import HumanMessage, AIMessage
8
 
9
+ import gradio as gr
10
 
11
# Credentials and the per-specifier assistant IDs, all read from the
# environment (populated by load_dotenv above).
api_key = os.getenv('OPENAI_API_KEY')
extractor_agents = {
    f"Solution Specifier {letter}": os.getenv(f'ASSISTANT_ID_SOLUTION_SPECIFIER_{letter}')
    for letter in "ABCD"
}
19
+
20
# Factory for a fresh assistant runnable bound to one assistant ID.
def get_extractor_llm(agent_id):
    """Wrap the OpenAI assistant *agent_id* in an agent-mode LangChain runnable."""
    runnable = OpenAIAssistantRunnable(
        assistant_id=agent_id,
        api_key=api_key,
        as_agent=True,
    )
    return runnable
23
+
24
# Utility: strip OpenAI-style citation markers from assistant output.
def remove_citation(text):
    """Replace every 【number†text】 citation marker in *text* with 📚."""
    citation_pattern = re.compile(r"【\d+†\w+】")
    return citation_pattern.sub("📚", text)
30
 
31
# Prediction function
def predict(message, history, selected_agent):
    """Answer *message* using the assistant picked in the dropdown.

    Parameters
    ----------
    message : str
        The user's latest chat message.
    history : list[tuple[str, str]]
        Gradio-style (human, ai) pairs from earlier turns.
    selected_agent : str
        Key into ``extractor_agents`` naming the assistant to use.

    Returns
    -------
    str
        The assistant's reply with citation markers replaced by 📚.
    """
    # Look up the assistant ID for the chosen specifier and build a runnable.
    agent_id = extractor_agents[selected_agent]
    extractor_llm = get_extractor_llm(agent_id)

    # FIX: the original converted `history` into LangChain Human/AI messages
    # and then never used the result, so every turn reached the model without
    # context anyway; the dead conversion is removed. TODO(review): actually
    # forward the conversation history — the invoke payload format for
    # OpenAIAssistantRunnable threads isn't visible from this file.
    gpt_response = extractor_llm.invoke({"content": message})
    output = gpt_response.return_values["output"]
    return remove_citation(output)
49
 
50
# Define the Gradio interface
def app_interface():
    """Build the chat UI: a ChatInterface plus an agent-selection dropdown.

    Returns the (unlaunched) gr.ChatInterface instance.
    """
    dropdown = gr.Dropdown(
        choices=list(extractor_agents.keys()),
        value="Solution Specifier A",
        label="Choose Extractor Agent",
    )
    # FIX: gr.ChatInterface takes extra widgets via `additional_inputs`, not
    # `inputs` — the original call raised TypeError at construction. The
    # pass-through lambda is also dropped: `predict` already has the
    # (message, history, selected_agent) signature ChatInterface calls when
    # one additional input is attached.
    chat = gr.ChatInterface(
        fn=predict,
        additional_inputs=[dropdown],
        title="Solution Specifier Chat",
        description="Test with different solution specifiers",
    )
    return chat

# Launch the app
chat_interface = app_interface()
chat_interface.launch(share=True)