AbenzaFran commited on
Commit
1734d74
·
verified ·
1 Parent(s): a7cca91

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -16
app.py CHANGED
@@ -3,35 +3,57 @@ import re
3
  from dotenv import load_dotenv
4
  load_dotenv()
5
 
6
- from langchain.agents.openai_assistant import OpenAIAssistantRunnable
7
- from langchain.agents import AgentExecutor
8
 
 
9
  from langchain.schema import HumanMessage, AIMessage
10
 
11
- import gradio
12
-
13
  api_key = os.getenv('OPENAI_API_KEY')
14
  extractor_agent = os.getenv('ASSISTANT_ID_SOLUTION_SPECIFIER_A')
15
 
16
- extractor_llm = OpenAIAssistantRunnable(assistant_id=extractor_agent, api_key=api_key, as_agent=True)
 
 
 
 
 
 
 
 
 
17
 
18
  def remove_citation(text):
19
- # Define the regex pattern to match the citation format 【number†text】
20
  pattern = r"【\d+†\w+】"
21
- # Replace the pattern with an empty string
22
  return re.sub(pattern, "📚", text)
23
 
24
  def predict(message, history):
25
- history_langchain_format = []
26
- for human, ai in history:
27
- history_langchain_format.append(HumanMessage(content=human))
28
- history_langchain_format.append(AIMessage(content=ai))
29
- history_langchain_format.append(HumanMessage(content=message))
30
- gpt_response = extractor_llm.invoke({"content": message})
31
- output = gpt_response.return_values["output"]
 
 
 
 
 
 
 
 
 
 
 
32
  non_cited_output = remove_citation(output)
 
 
33
  return non_cited_output
34
 
35
- #gradio.Markdown("Click [here](https://www.google.com) to visit Google.")
36
- chat = gradio.ChatInterface(predict, title="Solution Specifier A", description="testing for the time being")
 
 
 
 
37
  chat.launch(share=True)
 
3
  from dotenv import load_dotenv
4
  load_dotenv()
5
 
6
+ import gradio as gr
 
7
 
8
+ from langchain.agents.openai_assistant import OpenAIAssistantRunnable
9
  from langchain.schema import HumanMessage, AIMessage
10
 
 
 
11
  api_key = os.getenv('OPENAI_API_KEY')
12
  extractor_agent = os.getenv('ASSISTANT_ID_SOLUTION_SPECIFIER_A')
13
 
14
+ # Create the assistant. By default, we don't specify a thread_id,
15
+ # so the first call that doesn't pass one will create a new thread.
16
+ extractor_llm = OpenAIAssistantRunnable(
17
+ assistant_id=extractor_agent,
18
+ api_key=api_key,
19
+ as_agent=True
20
+ )
21
+
22
+ # We will store thread_id globally or in a session variable.
23
+ THREAD_ID = None
24
 
25
  def remove_citation(text):
 
26
  pattern = r"【\d+†\w+】"
 
27
  return re.sub(pattern, "📚", text)
28
 
29
  def predict(message, history):
30
+ """
31
+ Receives the new user message plus the entire conversation history
32
+ from Gradio. If no thread_id is set, we create a new thread.
33
+ Otherwise we pass the existing thread_id.
34
+ """
35
+ global THREAD_ID
36
+
37
+ # 1) Decide if we are creating a new thread or continuing the old one
38
+ if THREAD_ID is None:
39
+ # No thread_id yet -> this is the first user message
40
+ response = extractor_llm.invoke({"content": message})
41
+ THREAD_ID = response.thread_id # store for subsequent calls
42
+ else:
43
+ # We already have a thread_id -> continue that same thread
44
+ response = extractor_llm.invoke({"content": message, "thread_id": THREAD_ID})
45
+
46
+ # 2) Extract the text output from the response
47
+ output = response.return_values["output"]
48
  non_cited_output = remove_citation(output)
49
+
50
+ # 3) Return the model's text to display in Gradio
51
  return non_cited_output
52
 
53
+ # Create a Gradio ChatInterface using our predict function
54
+ chat = gr.ChatInterface(
55
+ fn=predict,
56
+ title="Solution Specifier A",
57
+ description="Testing threaded conversation"
58
+ )
59
  chat.launch(share=True)