AbenzaFran committed on
Commit
eed3e9c
·
verified ·
1 Parent(s): 6230219

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -56
app.py CHANGED
@@ -1,76 +1,67 @@
 
1
  import os
2
  import re
3
- import streamlit as st
4
  from dotenv import load_dotenv
 
 
 
 
5
  from langchain.agents.openai_assistant import OpenAIAssistantRunnable
 
6
 
7
- # Load environment variables
8
- load_dotenv()
9
- api_key = os.getenv("OPENAI_API_KEY")
10
- extractor_agent = os.getenv("ASSISTANT_ID_SOLUTION_SPECIFIER_A")
11
 
12
- # Create the assistant
 
13
  extractor_llm = OpenAIAssistantRunnable(
14
  assistant_id=extractor_agent,
15
  api_key=api_key,
16
  as_agent=True
17
  )
18
 
19
- def remove_citation(text: str) -> str:
 
 
 
20
  pattern = r"【\d+†\w+】"
21
  return re.sub(pattern, "πŸ“š", text)
22
 
23
- # Initialize session state for messages and thread_id
24
- if "messages" not in st.session_state:
25
- st.session_state["messages"] = []
26
- if "thread_id" not in st.session_state:
27
- st.session_state["thread_id"] = None
28
-
29
- st.title("Solution Specifier A")
30
-
31
- def predict(user_input: str) -> str:
32
  """
33
- This function calls our OpenAIAssistantRunnable to get a response.
34
- If we don't have a thread_id yet, we create a new thread on the first call.
35
- Otherwise, we continue the existing thread.
36
  """
37
- if st.session_state["thread_id"] is None:
38
- response = extractor_llm.invoke({"content": user_input})
39
- st.session_state["thread_id"] = response.thread_id
40
- else:
41
- response = extractor_llm.invoke(
42
- {"content": user_input, "thread_id": st.session_state["thread_id"]}
43
- )
44
- output = response.return_values["output"]
45
- return remove_citation(output)
46
-
47
- # Display any existing messages (from a previous run or refresh)
48
- for msg in st.session_state["messages"]:
49
- if msg["role"] == "user":
50
- with st.chat_message("user"):
51
- st.write(msg["content"])
52
- else:
53
- with st.chat_message("assistant"):
54
- st.write(msg["content"])
55
-
56
- # Create the chat input widget at the bottom of the page
57
- user_input = st.chat_input("Type your message here...")
58
-
59
- # When the user hits ENTER on st.chat_input
60
- if user_input:
61
- # Add the user message to session state
62
- st.session_state["messages"].append({"role": "user", "content": user_input})
63
 
64
- # Display the user's message
65
- with st.chat_message("user"):
66
- st.write(user_input)
67
-
68
- # Get the assistant's response
69
- response_text = predict(user_input)
70
 
71
- # Add the assistant response to session state
72
- st.session_state["messages"].append({"role": "assistant", "content": response_text})
 
 
 
 
 
 
 
 
 
 
73
 
74
- # Display the assistant's reply
75
- with st.chat_message("assistant"):
76
- st.write(response_text)
 
 
 
 
 
 
 
 
1
+
2
  import os
3
  import re
 
4
  from dotenv import load_dotenv
5
+ load_dotenv()
6
+
7
+ import gradio as gr
8
+
9
  from langchain.agents.openai_assistant import OpenAIAssistantRunnable
10
+ from langchain.schema import HumanMessage, AIMessage
11
 
12
+ api_key = os.getenv('OPENAI_API_KEY')
13
+ extractor_agent = os.getenv('ASSISTANT_ID_SOLUTION_SPECIFIER_A')
 
 
14
 
15
+ # Create the assistant. By default, we don't specify a thread_id,
16
+ # so the first call that doesn't pass one will create a new thread.
17
  extractor_llm = OpenAIAssistantRunnable(
18
  assistant_id=extractor_agent,
19
  api_key=api_key,
20
  as_agent=True
21
  )
22
 
23
+ # We will store thread_id globally or in a session variable.
24
+ THREAD_ID = None
25
+
26
+ def remove_citation(text):
27
  pattern = r"【\d+†\w+】"
28
  return re.sub(pattern, "πŸ“š", text)
29
 
30
+ def predict(message, history):
 
 
 
 
 
 
 
 
31
  """
32
+ Receives the new user message plus the entire conversation history
33
+ from Gradio. If no thread_id is set, we create a new thread.
34
+ Otherwise we pass the existing thread_id.
35
  """
36
+ global THREAD_ID
37
+
38
+ # debug print
39
+ print("current history:", history)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
+ # If history is empty, this means that it is probably a new conversation and therefore the thread shall be reset
42
+ if not history:
43
+ THREAD_ID = None
 
 
 
44
 
45
+ # 1) Decide if we are creating a new thread or continuing the old one
46
+ if THREAD_ID is None:
47
+ # No thread_id yet -> this is the first user message
48
+ response = extractor_llm.invoke({"content": message})
49
+ THREAD_ID = response.thread_id # store for subsequent calls
50
+ else:
51
+ # We already have a thread_id -> continue that same thread
52
+ response = extractor_llm.invoke({"content": message, "thread_id": THREAD_ID})
53
+
54
+ # 2) Extract the text output from the response
55
+ output = response.return_values["output"]
56
+ non_cited_output = remove_citation(output)
57
 
58
+ # 3) Return the model's text to display in Gradio
59
+ return non_cited_output
60
+
61
+ # Create a Gradio ChatInterface using our predict function
62
+ chat = gr.ChatInterface(
63
+ fn=predict,
64
+ title="Solution Specifier A",
65
+ #description="Testing threaded conversation"
66
+ )
67
+ chat.launch(share=True)